fix: Use a -cuda image when requesting a GPU (#366)
# The problem

Requesting a GPU on CE without specifying a custom image meant the
workflow would never run: no node offers both a GPU and the default
CPU-only image, so the task could never be scheduled.

# This PR's solution

Automatically adds the `-cuda` suffix to the default image when
requesting a GPU.
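
For illustration, a minimal sketch of the resulting image selection (it mirrors the `_get_default_image` helper added in the diff below; the `0.46.0` version tag is a placeholder, not taken from this PR):

```python
# Sketch only: same rule as the new _get_default_image helper.
DEFAULT_IMAGE_TEMPLATE = "hub.nexus.orquestra.io/zapatacomputing/orquestra-sdk-base:{}"


def default_image(sdk_version: str, num_gpus=None) -> str:
    # Append -cuda only when the invocation actually requests one or more GPUs.
    image = DEFAULT_IMAGE_TEMPLATE.format(sdk_version)
    if num_gpus is not None and num_gpus > 0:
        image = f"{image}-cuda"
    return image


print(default_image("0.46.0"))     # ...orquestra-sdk-base:0.46.0
print(default_image("0.46.0", 1))  # ...orquestra-sdk-base:0.46.0-cuda
```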

# Checklist

_Check that this PR satisfies the following items:_

- [x] Tests have been added for new features/changed behavior (if no new
features have been added, check the box).
- [x] The [changelog file](CHANGELOG.md) has been updated with a
user-readable description of the changes (if the change isn't visible to
the user in any way, check the box).
- [x] The PR's title is prefixed with
`<feat/fix/chore/imp[rovement]/int[ernal]/docs>[!]:`
- [x] The PR is linked to a JIRA ticket (if there's no suitable ticket,
check the box).
jamesclark-Zapata authored Mar 4, 2024
1 parent b1b2e24 commit 17c9fdb
Showing 3 changed files with 54 additions and 21 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,8 @@

🐛 *Bug Fixes*

* Requesting GPUs with the default image will now use a GPU image on CE.

💅 *Improvements*

🥷 *Internal*
39 changes: 24 additions & 15 deletions src/orquestra/sdk/_ray/_build_workflow.py
@@ -32,6 +32,13 @@
DEFAULT_IMAGE_TEMPLATE = "hub.nexus.orquestra.io/zapatacomputing/orquestra-sdk-base:{}"


def _get_default_image(template: str, sdk_version: str, num_gpus: t.Optional[int]):
    image = template.format(sdk_version)
    if num_gpus is not None and num_gpus > 0:
        image = f"{image}-cuda"
    return image


def _arg_from_graph(argument_id: ir.ArgumentId, workflow_def: ir.WorkflowDef):
    try:
        return workflow_def.constant_nodes[argument_id]
@@ -508,21 +515,6 @@ def make_ray_dag(
        "max_retries": 0,
    }

    # Set custom image
    if os.getenv(RAY_SET_CUSTOM_IMAGE_RESOURCES_ENV) is not None:
        # This makes an assumption that only "new" IRs will get to this point
        assert workflow_def.metadata is not None, "Expected a >=0.45.0 IR"
        sdk_version = workflow_def.metadata.sdk_version.original

        # Custom "Ray resources" request. The entries need to correspond to the ones
        # used when starting the Ray cluster. See also:
        # https://docs.ray.io/en/latest/ray-core/scheduling/resources.html#custom-resources
        ray_options["resources"] = _ray_resources_for_custom_image(
            invocation.custom_image
            or user_task.custom_image
            or DEFAULT_IMAGE_TEMPLATE.format(sdk_version)
        )

    # Non-custom task resources
    if invocation.resources is not None:
        if invocation.resources.cpu is not None:
@@ -540,6 +532,23 @@ def make_ray_dag(
            gpu = int(float(invocation.resources.gpu))
            ray_options["num_gpus"] = gpu

    # Set custom image
    if os.getenv(RAY_SET_CUSTOM_IMAGE_RESOURCES_ENV) is not None:
        # This makes an assumption that only "new" IRs will get to this point
        assert workflow_def.metadata is not None, "Expected a >=0.45.0 IR"
        sdk_version = workflow_def.metadata.sdk_version.original

        # Custom "Ray resources" request. The entries need to correspond to the ones
        # used when starting the Ray cluster. See also:
        # https://docs.ray.io/en/latest/ray-core/scheduling/resources.html#custom-resources
        ray_options["resources"] = _ray_resources_for_custom_image(
            invocation.custom_image
            or user_task.custom_image
            or _get_default_image(
                DEFAULT_IMAGE_TEMPLATE, sdk_version, ray_options.get("num_gpus")
            )
        )

    ray_result = _make_ray_dag_node(
        client=client,
        ray_options=ray_options,
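
Side note (my reading of the hunks above, not stated in the PR): the custom-image block moves below the resources handling so that `ray_options["num_gpus"]` is already populated when `_get_default_image` reads it via `ray_options.get("num_gpus")`. A minimal sketch of that ordering dependency, with placeholder values:

```python
# Placeholder sketch: shows why image selection must follow the resources block.
ray_options = {"max_retries": 0}


def wants_cuda(num_gpus) -> bool:
    # Same condition used by _get_default_image in the diff above.
    return num_gpus is not None and num_gpus > 0


too_early = wants_cuda(ray_options.get("num_gpus"))  # False: GPU request not recorded yet

# The resources block records the GPU request...
ray_options["num_gpus"] = 1

# ...so an image choice made afterwards picks the -cuda variant.
after_move = wants_cuda(ray_options.get("num_gpus"))  # True

assert (too_early, after_move) == (False, True)
```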
34 changes: 28 additions & 6 deletions tests/runtime/ray/test_build_workflow.py
@@ -3,7 +3,7 @@
################################################################################

import re
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union
from unittest.mock import ANY, Mock, call, create_autospec

import pytest
@@ -221,14 +221,31 @@ def test_setting_resources(
        assert isinstance(calls[0].kwargs[kwarg_name], type_)

    @pytest.mark.parametrize(
        "custom_image, expected_resources",
        "custom_image, gpu, expected_resources, expected_kwargs",
        (
            ("a_custom_image:latest", {"image:a_custom_image:latest": 1}),
            (
                "a_custom_image:latest",
                None,
                {"image:a_custom_image:latest": 1},
                {},
            ),
            (
                None,
                None,
                {
                    "image:hub.nexus.orquestra.io/zapatacomputing/orquestra-sdk-base:mocked": 1  # noqa: E501
                },
                {},
            ),
            (
                None,
                1,
                {
                    "image:hub.nexus.orquestra.io/zapatacomputing/orquestra-sdk-base:mocked-cuda": 1  # noqa: E501
                },
                {
                    "num_gpus": 1,
                },
            ),
        ),
    )
@@ -239,12 +256,14 @@ def test_with_env_set(
        wf_run_id: str,
        monkeypatch: pytest.MonkeyPatch,
        custom_image: Optional[str],
        gpu: Optional[int],
        expected_resources: Dict[str, int],
        expected_kwargs: Dict[str, Any],
    ):
        # Given
        monkeypatch.setenv("ORQ_RAY_SET_CUSTOM_IMAGE_RESOURCES", "1")
        workflow = workflow_parametrised_with_resources(
            custom_image=custom_image
            gpu=gpu, custom_image=custom_image
        ).model

        # To prevent hardcoding a version number, let's override the version for
@@ -263,7 +282,6 @@
        # We should only have two calls: our invocation and the aggregation step
        assert len(calls) == 2
        # Checking our call did not have any resources included

        assert calls[0] == call(
            ANY,
            name=ANY,
@@ -272,18 +290,21 @@
            catch_exceptions=ANY,
            max_retries=ANY,
            resources=expected_resources,
            **expected_kwargs,
        )

    def test_with_env_not_set(
        self,
        client: Mock,
        wf_run_id: str,
        custom_image: Optional[str],
        gpu: Optional[int],
        expected_resources: Dict[str, int],
        expected_kwargs: Dict[str, Any],
    ):
        # Given
        workflow = workflow_parametrised_with_resources(
            custom_image=custom_image
            gpu=gpu, custom_image=custom_image
        ).model

        # When
@@ -302,6 +323,7 @@
            runtime_env=ANY,
            catch_exceptions=ANY,
            max_retries=ANY,
            **expected_kwargs,
        )
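
For context on what `expected_resources` encodes (my inference from the test data above, not something shown in this diff): `_ray_resources_for_custom_image` appears to map an image name to a Ray custom-resource request keyed `image:<name>` with a count of 1. A sketch under that assumption:

```python
from typing import Dict


def ray_resources_for_custom_image(image: str) -> Dict[str, int]:
    # Assumed behaviour, inferred from the expected_resources cases above;
    # the real _ray_resources_for_custom_image may differ in detail.
    return {f"image:{image}": 1}


assert ray_resources_for_custom_image("a_custom_image:latest") == {
    "image:a_custom_image:latest": 1
}
```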

