Feat: Allow setting max retries on task (#367)
# The problem
There is a memory leak in a task that we can't fix (it comes from a 3rd-party
lib), and there is currently no way to force Ray not to reuse workers.
https://zapatacomputing.atlassian.net/browse/ORQSDK-1023

# This PR's solution
Instead of limiting worker reuse, users can set retries for a given task.
With that option, if the worker is OOMKilled due to memory leaked by
previous task calls, it can restart itself with fresh memory and continue
as if nothing had happened.
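
To illustrate (this snippet is not part of the diff below; it assumes the usual `import orquestra.sdk as sdk`, and the leaky function body is a hypothetical stand-in):

```python
import orquestra.sdk as sdk


@sdk.task(max_retries=3)
def leaky_task():
    # Hypothetical stand-in for a task that leaks memory via a 3rd-party lib.
    # If the worker is OOMKilled, Ray restarts it on a fresh worker process
    # (up to 3 times) instead of failing the whole workflow. Plain Python
    # exceptions are NOT retried.
    ...
```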

# Checklist

_Check that this PR satisfies the following items:_

- [x] Tests have been added for new features/changed behavior (if no new
features have been added, check the box).
- [x] The [changelog file](CHANGELOG.md) has been updated with a
user-readable description of the changes (if the change isn't visible to
the user in any way, check the box).
- [x] The PR's title is prefixed with
`<feat/fix/chore/imp[rovement]/int[ernal]/docs>[!]:`
- [x] The PR is linked to a JIRA ticket (if there's no suitable ticket,
check the box).
SebastianMorawiec authored Mar 7, 2024
1 parent 8cce024 commit fad1f84
Showing 8 changed files with 123 additions and 4 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,8 @@

🔥 *Features*

* Added `max_retries` to the `sdk.task` decorator. This allows users to restart Ray workers after system crashes (like OOMKills or SIGTERMs). Restarts do not happen on Python exceptions.

🧟 *Deprecations*

👩‍🔬 *Experimental*
12 changes: 12 additions & 0 deletions src/orquestra/sdk/_base/_dsl.py
@@ -515,6 +515,7 @@ def __init__(
        custom_image: Optional[str] = None,
        custom_name: Optional[str] = None,
        fn_ref: Optional[FunctionRef] = None,
        max_retries: Optional[int] = None,
    ):
        if isinstance(fn, BuiltinFunctionType):
            raise NotImplementedError("Built-in functions are not supported as Tasks")
@@ -531,6 +532,7 @@ def __init__(
        self._use_default_dependency_imports = dependency_imports is None
        self._source_import = source_import
        self._use_default_source_import = source_import is None
        self._max_retries = max_retries

        # task itself is not part of any workflow yet. Don't pass wf defaults
        self._resolve_task_source_data()
@@ -1093,6 +1095,7 @@ def task(
    n_outputs: Optional[int] = None,
    custom_image: Optional[str] = None,
    custom_name: Optional[str] = None,
    max_retries: Optional[int] = None,
) -> Callable[[Callable[_P, _R]], TaskDef[_P, _R]]:
    ...

@@ -1107,6 +1110,7 @@ def task(
    n_outputs: Optional[int] = None,
    custom_image: Optional[str] = None,
    custom_name: Optional[str] = None,
    max_retries: Optional[int] = None,
) -> TaskDef[_P, _R]:
    ...

@@ -1120,6 +1124,7 @@ def task(
    n_outputs: Optional[int] = None,
    custom_image: Optional[str] = None,
    custom_name: Optional[str] = None,
    max_retries: Optional[int] = None,
) -> Union[TaskDef[_P, _R], Callable[[Callable[_P, _R]], TaskDef[_P, _R]]]:
    """Wraps a function into an Orquestra Task.
@@ -1151,6 +1156,12 @@ def task(
            result of another task), a placeholder will be used. Every character
            that is non-alphanumeric will be changed to a dash ("-").
            Also, only the first 128 characters of the name will be used.
        max_retries: Maximum number of times a task will be retried after a
            failure. Useful if a worker is killed by random events or by memory
            leaked from previously executed tasks.
            WARNING: retries might cause issues in MLflow logging. Retried
            workers share the same invocation ID, so the MLflow identifier will
            be shared between them.
    Raises:
        ValueError: when a task has fewer than one output.
@@ -1188,6 +1199,7 @@ def _inner(fn: Callable[_P, _R]):
            output_metadata=output_metadata,
            custom_image=custom_image,
            custom_name=custom_name,
            max_retries=max_retries,
        )

        return task_def
1 change: 1 addition & 0 deletions src/orquestra/sdk/_base/_traversal.py
@@ -590,6 +590,7 @@ def _make_task_model(
        resources=resources,
        parameters=parameters,
        custom_image=task._custom_image,
        max_retries=task._max_retries,
    )


11 changes: 7 additions & 4 deletions src/orquestra/sdk/_ray/_build_workflow.py
@@ -506,13 +506,16 @@ def make_ray_dag(
         # If there are any python packages to install for step - set runtime env
         "runtime_env": (_client.RuntimeEnv(pip=pip) if len(pip) > 0 else None),
         "catch_exceptions": False,
-        # We only want to execute workflow tasks once. This is so there is only one
-        # task run ID per task, for scenarios where this is used (like in MLFlow).
+        # We only want to execute workflow tasks once by default.
+        # This is so there is only one task run ID per task, for scenarios where
+        # this is used (like in MLflow). We allow overriding this at the task
+        # level for particular edge cases, like memory leaks inside 3rd-party
+        # libraries, so that an OOMKilled worker can be restarted.
         # By default, Ray will only retry tasks that fail due to a "system error".
         # For example, if the worker process crashes or exits early.
         # Normal Python exceptions are NOT retried.
-        # So, we turn max_retries down to 0.
-        "max_retries": 0,
+        # So, we default max_retries to 0 unless the user overrides it.
+        "max_retries": user_task.max_retries if user_task.max_retries else 0,
     }

     # Non-custom task resources
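For context, here is a minimal, self-contained sketch (not from this repo; the marker-file trick is purely illustrative) of the Ray behavior the comment above relies on: `max_retries` re-runs a task whose worker process died, but never one that raised a Python exception.

```python
import os
import signal
import tempfile

import ray

# Marker file lets the re-run (in a fresh worker process) detect it's a retry.
MARKER = os.path.join(tempfile.gettempdir(), "ray_retry_demo_marker")


@ray.remote(max_retries=2)  # retries apply to system failures only
def flaky() -> str:
    if not os.path.exists(MARKER):
        open(MARKER, "w").close()
        # Simulate an OOMKill-style crash: the worker dies mid-task, so Ray
        # treats it as a system failure and schedules a retry.
        os.kill(os.getpid(), signal.SIGTERM)
    return "succeeded on retry"


if __name__ == "__main__":
    ray.init()
    print(ray.get(flaky.remote()))  # -> "succeeded on retry"
```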
1 change: 1 addition & 0 deletions src/orquestra/sdk/schema/ir.py
@@ -238,6 +238,7 @@ class TaskDef(BaseModel):

    resources: t.Optional[Resources] = None

    max_retries: t.Optional[int] = None
    # Hints the runtime to run this task in a docker container with this image. Has no
    # effect if the runtime doesn't support it.
    custom_image: t.Optional[str] = None
72 changes: 72 additions & 0 deletions tests/runtime/ray/test_integration.py
@@ -1340,3 +1340,75 @@ def wf():
    # Precondition
    wf_run = runtime.get_workflow_run_status(wf_run_id)
    assert wf_run.status.state == State.SUCCEEDED


@pytest.mark.slow
class TestRetries:
    """
    Test that retrying Ray workers works properly.
    """

    @pytest.mark.parametrize(
        "max_retries,should_fail",
        [
            (1, False),  # we should not fail with max_retries enabled
            (50, False),  # we should not fail with max_retries enabled
            (0, True),  # 0 means do not retry
            (None, True),  # we do not enable max_retries by default
        ],
    )
    def test_max_retries(self, runtime: _dag.RayRuntime, max_retries, should_fail):
        @sdk.task(max_retries=max_retries)
        def generic_task(*args):
            # Abuse a module attribute as state that survives across calls
            # within one worker process, but not across worker restarts. The
            # second call in the same worker kills the process to simulate a
            # system failure.
            if hasattr(sdk, "l"):
                sdk.l.extend([0])  # type: ignore # noqa
            else:
                setattr(sdk, "l", [0])
            if len(sdk.l) == 2:  # type: ignore # noqa
                import os
                import signal

                os.kill(os.getpid(), signal.SIGTERM)

            return None

        @sdk.workflow
        def wf():
            task_res = None
            for _ in range(5):
                task_res = generic_task(task_res)
            return task_res

        wf_model = wf().model

        # When
        # The function-under-test is called inside the workflow.
        wf_run_id = runtime.create_workflow_run(wf_model, project=None, dry_run=False)

        # We can't base our logic on the SDK workflow status because of:
        # https://zapatacomputing.atlassian.net/browse/ORQSDK-1024
        # We can only peek at the Ray status to see that the workflow actually
        # failed, even though we report it as RUNNING.
        import ray.workflow
        from ray.workflow.common import WorkflowStatus

        no_of_polls = 0

        while True:
            ray_status = ray.workflow.get_status(wf_run_id)
            if no_of_polls >= 30:
                break
            if ray_status == WorkflowStatus.RUNNING:
                time.sleep(1)
                no_of_polls += 1
                continue
            if (
                ray_status == WorkflowStatus.FAILED
                or ray_status == WorkflowStatus.SUCCESSFUL
            ):
                break

        if should_fail:
            assert ray_status == WorkflowStatus.FAILED
        else:
            assert ray_status == WorkflowStatus.SUCCESSFUL
8 changes: 8 additions & 0 deletions tests/sdk/test_dsl.py
@@ -545,6 +545,14 @@ def _local_task_1(x):
    assert len(warns.list) == 1


def test_max_retries():
    @_dsl.task(max_retries=5)
    def task():
        ...

    assert task._max_retries == 5


def test_default_import_type(monkeypatch):
    @_dsl.task
    def task():
20 changes: 20 additions & 0 deletions tests/sdk/test_traversal.py
@@ -974,6 +974,26 @@ def workflow():
    assert list(workflow.model.task_invocations.keys())[0] == expected


@pytest.mark.parametrize(
    "argument, expected",
    [
        (None, None),
        (1, 1),
        (999, 999),
    ],
)
def test_max_retries_in_model(argument, expected):
    @_dsl.task(max_retries=argument)
    def task():
        ...

    @_workflow.workflow()
    def workflow():
        return task()

    assert list(workflow.model.tasks.values())[0].max_retries == expected


class TestNumberOfFetchesOnInferRepos:
    @pytest.fixture()
    def setup_fetch(self, monkeypatch):