From f87718853f657144e24fe3d0c8188eacbe8ff3f0 Mon Sep 17 00:00:00 2001 From: James Clark <70290797+jamesclark-Zapata@users.noreply.github.com> Date: Tue, 5 Dec 2023 11:55:14 +0000 Subject: [PATCH] feat: Allow head node resources (#344) # The problem We want users to be able to select the head node resources on CE. # This PR's solution - adds a new kwarg to the workflow decorator - deprecates `data_aggregation` kwarg to the workflow decorator # Checklist _Check that this PR satisfies the following items:_ - [x] Tests have been added for new features/changed behavior (if no new features have been added, check the box). - [x] The [changelog file](CHANGELOG.md) has been updated with a user-readable description of the changes (if the change isn't visible to the user in any way, check the box). - [x] The PR's title is prefixed with `[!]:` - [x] The PR is linked to a JIRA ticket (if there's no suitable ticket, check the box). --------- Co-authored-by: Sebastian Morawiec --- CHANGELOG.md | 2 + .../sdk/_base/_driver/_ce_runtime.py | 25 ++++- src/orquestra/sdk/_base/_driver/_client.py | 3 + src/orquestra/sdk/_base/_driver/_models.py | 17 +++- src/orquestra/sdk/_base/_workflow.py | 24 ++++- src/orquestra/sdk/exceptions.py | 6 ++ tests/sdk/driver/test_ce_runtime.py | 92 ++++++++++++++++++- tests/sdk/driver/test_client.py | 29 ++++-- tests/sdk/test_traversal.py | 52 ++--------- tests/sdk/test_workflow.py | 72 +++++++-------- 10 files changed, 226 insertions(+), 96 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ae15bece..a19d6b848 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,9 +7,11 @@ 🔥 *Features* * Listing workflow runs on CE now supports filtering by state and max age. +* `head_node_resources` keyword argument has been added to the `@sdk.workflow` decorator to control the resources that the head node has available on CE. 🧟 *Deprecations* * Deprecation of `project_dir` argument for all public API functions. +* `data_aggregation` keyword argument to `@sdk.workflow` decorator is deprecated and will be removed in a future version. 👩‍🔬 *Experimental* diff --git a/src/orquestra/sdk/_base/_driver/_ce_runtime.py b/src/orquestra/sdk/_base/_driver/_ce_runtime.py index e4ffc4866..1e181c4fc 100644 --- a/src/orquestra/sdk/_base/_driver/_ce_runtime.py +++ b/src/orquestra/sdk/_base/_driver/_ce_runtime.py @@ -2,6 +2,7 @@ # © Copyright 2022-2023 Zapata Computing Inc. 
################################################################################ """RuntimeInterface implementation that uses Compute Engine.""" +import warnings from datetime import timedelta from pathlib import Path from typing import Dict, List, Optional, Protocol, Sequence, Union @@ -12,6 +13,7 @@ from orquestra.sdk._base._logs._interfaces import LogOutput, WorkflowLogs from orquestra.sdk._base._logs._models import LogAccumulator, LogStreamType from orquestra.sdk._base.abc import RuntimeInterface +from orquestra.sdk.exceptions import IgnoredFieldWarning from orquestra.sdk.kubernetes.quantity import parse_quantity from orquestra.sdk.schema.configs import RuntimeConfiguration from orquestra.sdk.schema.ir import ArtifactFormat, TaskInvocationId, WorkflowDef @@ -156,11 +158,32 @@ def create_workflow_run( _verify_workflow_resources(resources, max_invocation_resources) + if ( + workflow_def.data_aggregation is not None + and workflow_def.data_aggregation.resources is not None + ): + head_node_resources = _models.HeadNodeResources( + cpu=workflow_def.data_aggregation.resources.cpu, + memory=workflow_def.data_aggregation.resources.memory, + ) + if workflow_def.data_aggregation.resources.gpu: + warnings.warn( + "Head node resources will ignore GPU settings", + category=IgnoredFieldWarning, + ) + if workflow_def.data_aggregation.resources.nodes: + warnings.warn( + 'Head node resources will ignore "nodes" settings', + category=IgnoredFieldWarning, + ) + else: + head_node_resources = None + try: workflow_def_id = self._client.create_workflow_def(workflow_def, project) workflow_run_id = self._client.create_workflow_run( - workflow_def_id, resources, dry_run + workflow_def_id, resources, dry_run, head_node_resources ) except _exceptions.InvalidWorkflowDef as e: raise exceptions.WorkflowSyntaxError( diff --git a/src/orquestra/sdk/_base/_driver/_client.py b/src/orquestra/sdk/_base/_driver/_client.py index d13f6a78c..fb37c2e5b 100644 --- a/src/orquestra/sdk/_base/_driver/_client.py +++ b/src/orquestra/sdk/_base/_driver/_client.py @@ -471,6 +471,7 @@ def create_workflow_run( workflow_def_id: _models.WorkflowDefID, resources: _models.Resources, dry_run: bool, + head_node_resources: Optional[_models.HeadNodeResources], ) -> _models.WorkflowRunID: """Submit a workflow def to run in the workflow driver. @@ -478,6 +479,7 @@ def create_workflow_run( workflow_def_id: ID of the workflow definition to be submitted. resources: The resources required to execute the workflow. dry_run: Run the workflow without actually executing any task code. + head_node_resources: the requested resources for the head node Raises: orquestra.sdk._base._driver._exceptions.InvalidWorkflowRunRequest: when an @@ -497,6 +499,7 @@ def create_workflow_run( workflowDefinitionID=workflow_def_id, resources=resources, dryRun=dry_run, + headNodeResources=head_node_resources, ).dict(), ) diff --git a/src/orquestra/sdk/_base/_driver/_models.py b/src/orquestra/sdk/_base/_driver/_models.py index f4e55065d..6f227b67d 100644 --- a/src/orquestra/sdk/_base/_driver/_models.py +++ b/src/orquestra/sdk/_base/_driver/_models.py @@ -287,15 +287,30 @@ class Resources(pydantic.BaseModel): gpu: Optional[str] = pydantic.Field(regex="^[01]+$") +class HeadNodeResources(pydantic.BaseModel): + """ + Implements: + https://github.com/zapatacomputing/workflow-driver/blob/ac1e97ea00fc3526c93187a1da02170bff45b74f/openapi/src/schemas/HeadNodeResources.yaml. 
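+    Both fields accept Kubernetes-style resource quantity strings,
+    e.g. "2000m" for CPU or "10Gi" for memory.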
+ """ # noqa: D205, D212 + + cpu: Optional[str] = pydantic.Field( + regex=r"^([+-]?[0-9.]+)([eEinumkKMGTP]*[-+]?[0-9]*)$" + ) + memory: Optional[str] = pydantic.Field( + regex=r"^([+-]?[0-9.]+)([eEinumkKMGTP]*[-+]?[0-9]*)$" + ) + + class CreateWorkflowRunRequest(pydantic.BaseModel): """ Implements: - https://github.com/zapatacomputing/workflow-driver/blob/be7f293b052a0fee7b012badebef72dee02a2ebe/openapi/src/schemas/CreateWorkflowRunRequest.yaml. + https://github.com/zapatacomputing/workflow-driver/blob/ac1e97ea00fc3526c93187a1da02170bff45b74f/openapi/src/schemas/CreateWorkflowRunRequest.yaml. """ # noqa: D205, D212 workflowDefinitionID: WorkflowDefID resources: Resources dryRun: bool + headNodeResources: Optional[Resources] class CreateWorkflowRunResponse(pydantic.BaseModel): diff --git a/src/orquestra/sdk/_base/_workflow.py b/src/orquestra/sdk/_base/_workflow.py index d85be95cb..4cab3d57e 100644 --- a/src/orquestra/sdk/_base/_workflow.py +++ b/src/orquestra/sdk/_base/_workflow.py @@ -539,6 +539,7 @@ def workflow(fn: Callable[_P, _R]) -> WorkflowTemplate[_P, _R]: def workflow( *, resources: Optional[_dsl.Resources] = None, + head_node_resources: Optional[_dsl.Resources] = None, data_aggregation: Optional[Union[DataAggregation, bool]] = None, custom_name: Optional[str] = None, default_source_import: Optional[Import] = None, @@ -551,6 +552,7 @@ def workflow( fn: Optional[Callable[_P, _R]] = None, *, resources: Optional[_dsl.Resources] = None, + head_node_resources: Optional[_dsl.Resources] = None, data_aggregation: Optional[Union[DataAggregation, bool]] = None, custom_name: Optional[str] = None, default_source_import: Optional[Import] = None, @@ -569,9 +571,10 @@ def workflow( 10 nodes with 20 CPUs and a GPU each would be: resources=sdk.Resources(cpu="20", gpu="1", nodes=10) If omitted, the cluster's default resources will be used. - data_aggregation: Used to set up resources used during data step. If skipped, - or assigned True default values will be used. If assigned False - data aggregation step will not run. + head_node_resources: !Unstable API! The resources that the head node requires. + Only used on Compute Engine. If omitted, the cluster's default head node + resources will be used. + data_aggregation: Deprecated. custom_name: custom name for the workflow default_source_import: Set the default source import for all tasks inside this workflow. @@ -622,6 +625,19 @@ def wf(): elif default_dependency_imports is not None: workflow_default_dependency_imports = tuple(default_dependency_imports) + if data_aggregation is not None: + warnings.warn( + "data_aggregation argument is deprecated and will be removed " + "in upcoming versions of orquestra-sdk. 
It will be ignored " + "for this workflow.", + category=FutureWarning, + ) + + if head_node_resources is not None: + _data_aggregation = DataAggregation(resources=head_node_resources) + else: + _data_aggregation = None + def _inner(fn: Callable[_P, _R]): signature = inspect.signature(fn) name = custom_name @@ -632,7 +648,7 @@ def _inner(fn: Callable[_P, _R]): workflow_fn=fn, fn_ref=fn_ref, is_parametrized=len(signature.parameters) > 0, - data_aggregation=data_aggregation, + data_aggregation=_data_aggregation, default_source_import=default_source_import, default_dependency_imports=workflow_default_dependency_imports, ) diff --git a/src/orquestra/sdk/exceptions.py b/src/orquestra/sdk/exceptions.py index 7660f1a91..7eeb6e96c 100644 --- a/src/orquestra/sdk/exceptions.py +++ b/src/orquestra/sdk/exceptions.py @@ -335,3 +335,9 @@ class WorkspacesNotSupportedError(BaseRuntimeError): """When a non-workspaces supporting runtime gets a workspaces-related request.""" pass + + +class IgnoredFieldWarning(Warning): + """Raised when a requested feature is not supported on the selected runtime.""" + + pass diff --git a/tests/sdk/driver/test_ce_runtime.py b/tests/sdk/driver/test_ce_runtime.py index 54e956fa7..cef94a6f9 100644 --- a/tests/sdk/driver/test_ce_runtime.py +++ b/tests/sdk/driver/test_ce_runtime.py @@ -12,6 +12,7 @@ from orquestra.sdk._base._driver import _ce_runtime, _client, _exceptions, _models from orquestra.sdk._base._spaces._structs import ProjectRef from orquestra.sdk._base._testing._example_wfs import ( + add, my_workflow, workflow_parametrised_with_resources, workflow_with_different_resources, @@ -94,6 +95,7 @@ def test_happy_path( workflow_def_id, _models.Resources(cpu=None, memory=None, gpu=None, nodes=None), False, + None, ) assert isinstance(wf_run_id, WorkflowRunId) assert ( @@ -124,6 +126,7 @@ def test_with_memory( workflow_def_id, _models.Resources(cpu=None, memory="10Gi", gpu=None, nodes=None), False, + None, ) def test_with_cpu( @@ -149,6 +152,7 @@ def test_with_cpu( workflow_def_id, _models.Resources(cpu="1000m", memory=None, gpu=None, nodes=None), False, + None, ) def test_with_gpu( @@ -172,6 +176,7 @@ def test_with_gpu( workflow_def_id, _models.Resources(cpu=None, memory=None, gpu="1", nodes=None), False, + None, ) def test_maximum_resource( @@ -195,6 +200,7 @@ def test_maximum_resource( workflow_def_id, _models.Resources(cpu="5000m", memory="3G", gpu="1", nodes=None), False, + None, ) def test_resources_from_workflow( @@ -222,8 +228,89 @@ def test_resources_from_workflow( workflow_def_id, _models.Resources(cpu="1", memory="1.5G", gpu="1", nodes=20), False, + None, ) + @pytest.mark.parametrize( + "head_node_resources", + ( + sdk.Resources(cpu=None, memory=None), + None, + ), + ) + def test_with_no_head_node_resources( + self, + mocked_client: MagicMock, + runtime: _ce_runtime.CERuntime, + workflow_def_id: str, + workflow_run_id: str, + head_node_resources: sdk.Resources, + ): + # Given + @sdk.workflow( + head_node_resources=head_node_resources, + ) + def wf(): + return add(1, 2) + + mocked_client.create_workflow_def.return_value = workflow_def_id + mocked_client.create_workflow_run.return_value = workflow_run_id + + # When + _ = runtime.create_workflow_run( + wf().model, + None, + dry_run=False, + ) + + # Then + mocked_client.create_workflow_run.assert_called_once_with( + workflow_def_id, + _models.Resources(cpu=None, memory=None, nodes=None, gpu=None), + False, + None, + ) + + @pytest.mark.parametrize( + "cpu, memory,", + ( + ("1", None), + (None, "10Gi"), + ("2", "20Gi"), + 
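+            # Kubernetes-style quantity strings: bare CPU counts, memory with binary suffixes.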
), + ) + def test_with_head_node_resources( + self, + mocked_client: MagicMock, + runtime: _ce_runtime.CERuntime, + workflow_def_id: str, + workflow_run_id: str, + cpu: str, + memory: str, + ): + # Given + @sdk.workflow(head_node_resources=sdk.Resources(cpu=cpu, memory=memory)) + def wf(): + return add(1, 2) + + mocked_client.create_workflow_def.return_value = workflow_def_id + mocked_client.create_workflow_run.return_value = workflow_run_id + + # When + _ = runtime.create_workflow_run( + wf().model, + None, + dry_run=False, + ) + + # Then + mocked_client.create_workflow_run.assert_called_once_with( + workflow_def_id, + _models.Resources(cpu=None, memory=None, nodes=None, gpu=None), + False, + _models.HeadNodeResources(cpu=cpu, memory=memory), + ) + class TestWorkflowDefFailure: def test_invalid_wf_def( self, mocked_client: MagicMock, runtime: _ce_runtime.CERuntime @@ -1988,5 +2075,8 @@ def wf(): assert all([telltale in str(exec_info) for telltale in telltales]) else: mocked_client.create_workflow_run.assert_called_once_with( - workflow_def_id, expected_resources, False + workflow_def_id, + expected_resources, + False, + None, ) diff --git a/tests/sdk/driver/test_client.py b/tests/sdk/driver/test_client.py index b108b567c..013c2438c 100644 --- a/tests/sdk/driver/test_client.py +++ b/tests/sdk/driver/test_client.py @@ -1223,7 +1223,10 @@ def test_invalid_request( with pytest.raises(_exceptions.InvalidWorkflowRunRequest): _ = client.create_workflow_run( - workflow_def_id, resources, dry_run=False + workflow_def_id, + resources, + dry_run=False, + head_node_resources=None, ) @staticmethod @@ -1247,7 +1250,9 @@ def test_invalid_sdk_version( ) with pytest.raises(_exceptions.UnsupportedSDKVersion) as exc_info: - _ = client.create_workflow_run(workflow_def_id, resources, False) + _ = client.create_workflow_run( + workflow_def_id, resources, False, head_node_resources=None + ) assert exc_info.value.submitted_version == submitted_version assert exc_info.value.supported_versions == ["0.2.0", "0.3.0"] @@ -1271,7 +1276,9 @@ def test_invalid_sdk_version_unparsed_fallback( ) with pytest.raises(_exceptions.UnsupportedSDKVersion) as exc_info: - _ = client.create_workflow_run(workflow_def_id, resources, False) + _ = client.create_workflow_run( + workflow_def_id, resources, False, head_node_resources=None + ) assert exc_info.value.submitted_version is None assert exc_info.value.supported_versions is None @@ -1293,7 +1300,9 @@ def test_sets_auth( ], ) - client.create_workflow_run(workflow_def_id, resources, False) + client.create_workflow_run( + workflow_def_id, resources, False, head_node_resources=None + ) # The assertion is done by mocked_responses @@ -1311,7 +1320,9 @@ def test_unauthorized( ) with pytest.raises(_exceptions.InvalidTokenError): - _ = client.create_workflow_run(workflow_def_id, resources, False) + _ = client.create_workflow_run( + workflow_def_id, resources, False, head_node_resources=None + ) @staticmethod def test_forbidden( @@ -1327,7 +1338,9 @@ def test_forbidden( ) with pytest.raises(_exceptions.ForbiddenError): - _ = client.create_workflow_run(workflow_def_id, resources, False) + _ = client.create_workflow_run( + workflow_def_id, resources, False, head_node_resources=None + ) @staticmethod def test_unknown_error( @@ -1343,7 +1356,9 @@ def test_unknown_error( ) with pytest.raises(_exceptions.UnknownHTTPError): - _ = client.create_workflow_run(workflow_def_id, resources, False) + _ = client.create_workflow_run( + workflow_def_id, resources, False, head_node_resources=None + ) class 
TestTerminate: @staticmethod diff --git a/tests/sdk/test_traversal.py b/tests/sdk/test_traversal.py index 17f7666b2..67c115a1a 100644 --- a/tests/sdk/test_traversal.py +++ b/tests/sdk/test_traversal.py @@ -323,27 +323,6 @@ def constant_collisions(): return [1.0, 1, True, simple_task(1)] -@_workflow.workflow( - data_aggregation=_dsl.DataAggregation(run=False, resources=_dsl.Resources(cpu="1")) -) -def workflow_with_data_aggregation(): - return [simple_task(1)] - - -@_workflow.workflow(data_aggregation=_dsl.DataAggregation(run=True)) -def workflow_with_data_aggregation_no_resources(): - return [simple_task(1)] - - -@_workflow.workflow( - data_aggregation=_dsl.DataAggregation( - run=True, resources=_dsl.Resources(cpu="1", gpu="123") - ) -) -def workflow_with_data_aggregation_set_gpu(): - return [simple_task(1)] - - @_dsl.task() def simple_task(a): return a @@ -466,29 +445,18 @@ def test_workflow_without_data_aggregation(self): wf = constant_return.model assert wf.data_aggregation is None + @pytest.mark.filterwarnings("ignore:data_aggregation") def test_workflow_with_data_aggregation(self): + @_workflow.workflow( + data_aggregation=_workflow.DataAggregation( + resources=_dsl.Resources(cpu="1") + ) + ) + def workflow_with_data_aggregation(): + return capitalize("hello there") + wf = workflow_with_data_aggregation.model - assert wf.data_aggregation is not None - assert wf.data_aggregation.resources is not None - assert wf.data_aggregation.run is False - assert wf.data_aggregation.resources.cpu == "1" - - def test_workflow_with_data_aggregation_no_resources(self): - wf = workflow_with_data_aggregation_no_resources.model - assert wf.data_aggregation is not None - assert wf.data_aggregation.resources is None - assert wf.data_aggregation.run is True - - def test_workflow_with_data_aggregation_set_gpu(self): - # setting decorator explicitly to check for warnings - with pytest.warns(Warning) as warns: - wf = workflow_with_data_aggregation_set_gpu.model - assert len(warns.list) == 1 - assert wf.data_aggregation is not None - assert wf.data_aggregation.resources is not None - assert wf.data_aggregation.run is True - assert wf.data_aggregation.resources.gpu == "0" - assert wf.data_aggregation.resources.cpu == "1" + assert wf.data_aggregation is None def test_large_workflow(self): wf = large_workflow.model diff --git a/tests/sdk/test_workflow.py b/tests/sdk/test_workflow.py index 28d87d16c..9c6a5d6ea 100644 --- a/tests/sdk/test_workflow.py +++ b/tests/sdk/test_workflow.py @@ -77,51 +77,43 @@ def test_workflow_with_fake_imported_task(): _ = faked_task_wf() -class TestDataAggregationResources: - @staticmethod - @sdk.workflow( - data_aggregation=sdk.DataAggregation(resources=sdk.Resources(gpu="1g")) - ) - def _workflow_gpu_set_for_data_aggregation(): - return [_an_empty_task()] +@pytest.mark.parametrize( + "data_agg", + ( + True, + False, + sdk.DataAggregation(), + sdk.DataAggregation(resources=sdk.Resources(cpu="1")), + ), +) +def test_workflow_with_data_aggregation(data_agg): + with pytest.warns(Warning) as warns: - @staticmethod - @sdk.workflow(data_aggregation=False) - def _workflow_data_aggregation_false(): - return [_an_empty_task()] + @sdk.workflow( + data_aggregation=data_agg, + ) + def _(): + return [_an_empty_task()] - @staticmethod - @sdk.workflow(data_aggregation=True) - def _workflow_data_aggregation_true(): - return [_an_empty_task()] + assert len(warns.list) == 1 - @staticmethod - @sdk.workflow - def _workflow_no_data_aggregation(): + +@pytest.mark.parametrize( + "head_node_resources, 
expected_data_agg", + ( + (sdk.Resources(), sdk.DataAggregation(resources=sdk.Resources())), + (None, None), + (sdk.Resources(cpu="1"), sdk.DataAggregation(resources=sdk.Resources(cpu="1"))), + ), +) +def test_workflow_with_head_node_resources(head_node_resources, expected_data_agg): + @sdk.workflow( + head_node_resources=head_node_resources, + ) + def wf(): return [_an_empty_task()] - def test_workflow_with_gpu_set(self): - with pytest.warns(Warning) as warns: - wf = self._workflow_gpu_set_for_data_aggregation() - assert len(warns.list) == 1 - assert wf.model.data_aggregation is not None - assert wf.model.data_aggregation.resources is not None - assert wf.model.data_aggregation.resources.gpu == "0" - - def test_workflow_data_aggregation_false(self): - wf = self._workflow_data_aggregation_false() - assert wf.model.data_aggregation is not None - assert wf.model.data_aggregation.run is False - - def test_workflow_data_aggregation_true(self): - # verify that if user passes dataAggregation=True, it falls back to default - # as if nothing was provided - wf_data_aggregation_true = self._workflow_data_aggregation_true() - wf_data_aggregation_default = self._workflow_no_data_aggregation() - assert ( - wf_data_aggregation_true.model.data_aggregation - == wf_data_aggregation_default.data_aggregation - ) + assert wf._data_aggregation == expected_data_agg class TestModelsSerializeProperly:
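
---

Example usage: a minimal sketch of the new kwarg, not part of the diff above. The task body and the `"ce"` runtime config name are illustrative placeholders, not from this patch.

```python
from orquestra import sdk


@sdk.task()
def add(a, b):
    return a + b


# Request a 2-CPU / 10Gi head node on Compute Engine. GPU and node-count
# settings are not supported for the head node; if set, the CE runtime
# emits an IgnoredFieldWarning and ignores them at submission time.
@sdk.workflow(head_node_resources=sdk.Resources(cpu="2", memory="10Gi"))
def wf():
    return add(1, 2)


wf_run = wf().run("ce")  # "ce" stands in for a configured CE runtime
```

The deprecated path still parses: passing `data_aggregation=...` now emits a `FutureWarning` and the value is ignored for the workflow.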