diff --git a/CHANGELOG.md b/CHANGELOG.md index 7ad099f4c..e7767eea9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,8 @@ 🐛 *Bug Fixes* +* Requesting GPUs with the default image will now use a GPU image on CE. + 💅 *Improvements* 🥷 *Internal* diff --git a/src/orquestra/sdk/_ray/_build_workflow.py b/src/orquestra/sdk/_ray/_build_workflow.py index 866712be9..7992110a7 100644 --- a/src/orquestra/sdk/_ray/_build_workflow.py +++ b/src/orquestra/sdk/_ray/_build_workflow.py @@ -32,6 +32,13 @@ DEFAULT_IMAGE_TEMPLATE = "hub.nexus.orquestra.io/zapatacomputing/orquestra-sdk-base:{}" +def _get_default_image(template: str, sdk_version: str, num_gpus: t.Optional[int]): + image = template.format(sdk_version) + if num_gpus is not None and num_gpus > 0: + image = f"{image}-cuda" + return image + + def _arg_from_graph(argument_id: ir.ArgumentId, workflow_def: ir.WorkflowDef): try: return workflow_def.constant_nodes[argument_id] @@ -508,21 +515,6 @@ def make_ray_dag( "max_retries": 0, } - # Set custom image - if os.getenv(RAY_SET_CUSTOM_IMAGE_RESOURCES_ENV) is not None: - # This makes an assumption that only "new" IRs will get to this point - assert workflow_def.metadata is not None, "Expected a >=0.45.0 IR" - sdk_version = workflow_def.metadata.sdk_version.original - - # Custom "Ray resources" request. The entries need to correspond to the ones - # used when starting the Ray cluster. See also: - # https://docs.ray.io/en/latest/ray-core/scheduling/resources.html#custom-resources - ray_options["resources"] = _ray_resources_for_custom_image( - invocation.custom_image - or user_task.custom_image - or DEFAULT_IMAGE_TEMPLATE.format(sdk_version) - ) - # Non-custom task resources if invocation.resources is not None: if invocation.resources.cpu is not None: @@ -540,6 +532,23 @@ def make_ray_dag( gpu = int(float(invocation.resources.gpu)) ray_options["num_gpus"] = gpu + # Set custom image + if os.getenv(RAY_SET_CUSTOM_IMAGE_RESOURCES_ENV) is not None: + # This makes an assumption that only "new" IRs will get to this point + assert workflow_def.metadata is not None, "Expected a >=0.45.0 IR" + sdk_version = workflow_def.metadata.sdk_version.original + + # Custom "Ray resources" request. The entries need to correspond to the ones + # used when starting the Ray cluster. See also: + # https://docs.ray.io/en/latest/ray-core/scheduling/resources.html#custom-resources + ray_options["resources"] = _ray_resources_for_custom_image( + invocation.custom_image + or user_task.custom_image + or _get_default_image( + DEFAULT_IMAGE_TEMPLATE, sdk_version, ray_options.get("num_gpus") + ) + ) + ray_result = _make_ray_dag_node( client=client, ray_options=ray_options, diff --git a/tests/runtime/ray/test_build_workflow.py b/tests/runtime/ray/test_build_workflow.py index c49c4a333..d7fcd3b93 100644 --- a/tests/runtime/ray/test_build_workflow.py +++ b/tests/runtime/ray/test_build_workflow.py @@ -3,7 +3,7 @@ ################################################################################ import re -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union from unittest.mock import ANY, Mock, call, create_autospec import pytest @@ -221,14 +221,31 @@ def test_setting_resources( assert isinstance(calls[0].kwargs[kwarg_name], type_) @pytest.mark.parametrize( - "custom_image, expected_resources", + "custom_image, gpu, expected_resources, expected_kwargs", ( - ("a_custom_image:latest", {"image:a_custom_image:latest": 1}), ( + "a_custom_image:latest", + None, + {"image:a_custom_image:latest": 1}, + {}, + ), + ( + None, None, { "image:hub.nexus.orquestra.io/zapatacomputing/orquestra-sdk-base:mocked": 1 # noqa: E501 }, + {}, + ), + ( + None, + 1, + { + "image:hub.nexus.orquestra.io/zapatacomputing/orquestra-sdk-base:mocked-cuda": 1 # noqa: E501 + }, + { + "num_gpus": 1, + }, ), ), ) @@ -239,12 +256,14 @@ def test_with_env_set( wf_run_id: str, monkeypatch: pytest.MonkeyPatch, custom_image: Optional[str], + gpu: Optional[int], expected_resources: Dict[str, int], + expected_kwargs: Dict[str, Any], ): # Given monkeypatch.setenv("ORQ_RAY_SET_CUSTOM_IMAGE_RESOURCES", "1") workflow = workflow_parametrised_with_resources( - custom_image=custom_image + gpu=gpu, custom_image=custom_image ).model # To prevent hardcoding a version number, let's override the version for @@ -263,7 +282,6 @@ def test_with_env_set( # We should only have two calls: our invocation and the aggregation step assert len(calls) == 2 # Checking our call did not have any resources included - assert calls[0] == call( ANY, name=ANY, @@ -272,6 +290,7 @@ def test_with_env_set( catch_exceptions=ANY, max_retries=ANY, resources=expected_resources, + **expected_kwargs, ) def test_with_env_not_set( @@ -279,11 +298,13 @@ def test_with_env_not_set( client: Mock, wf_run_id: str, custom_image: Optional[str], + gpu: Optional[int], expected_resources: Dict[str, int], + expected_kwargs: Dict[str, Any], ): # Given workflow = workflow_parametrised_with_resources( - custom_image=custom_image + gpu=gpu, custom_image=custom_image ).model # When @@ -302,6 +323,7 @@ def test_with_env_not_set( runtime_env=ANY, catch_exceptions=ANY, max_retries=ANY, + **expected_kwargs, )