diff --git a/mars/dataframe/utils.py b/mars/dataframe/utils.py
index 8958a0b485..f0ebfb3fba 100644
--- a/mars/dataframe/utils.py
+++ b/mars/dataframe/utils.py
@@ -33,10 +33,6 @@
 from ..core import Entity, ExecutableTuple
 from ..core.context import Context, get_context
 from ..lib.mmh3 import hash as mmh_hash
-from ..services.task.execution.ray.context import (
-    RayExecutionContext,
-    RayExecutionWorkerContext,
-)
 from ..tensor.utils import dictify_chunk_size, normalize_chunk_sizes
 from ..typing import ChunkType, TileableType
 from ..utils import (
@@ -46,7 +42,7 @@
     ModulePlaceholder,
     is_full_slice,
     parse_readable_size,
-    is_ray_address,
+    is_on_ray,
 )

 try:
@@ -1437,6 +1433,9 @@ def _concat_chunks(merge_chunks: List[ChunkType], output_index: int):
     return new_op.new_tileable(df_or_series.op.inputs, kws=[params])


+# TODO: the clean_up_func, is_on_ray and restore_func functions may be
+# removed or refactored in the future to calculate the func size more
+# accurately and to address some serialization issues.
 def clean_up_func(op):
     closure_clean_up_bytes_threshold = int(
         os.getenv("MARS_CLOSURE_CLEAN_UP_BYTES_THRESHOLD", 10**4)
@@ -1461,7 +1460,7 @@ def clean_up_func(op):
         op.logic_key is not None
     ), f"Logic key of {op} wasn't calculated before cleaning up func."
     logger.debug(f"{op} need cleaning up func.")
-    if _is_on_ray(ctx):
+    if is_on_ray(ctx):
         import ray

         op.func_key = ray.put(op.func)
@@ -1470,27 +1469,10 @@
         op.func = cloudpickle.dumps(op.func)


-def _is_on_ray(ctx):
-    # There are three conditions
-    # a. mars backend
-    # b. ray backend(oscar), c. ray backend(dag)
-    # When a. or b. is selected, ctx is an instance of ThreadedServiceContext.
-    # The main difference between them is whether worker_address matches ray scheme.
-    # To avoid duplicated checks, here we choose the first worker address.
-    # When c. is selected, ctx is an instance of RayExecutionContext or RayExecutionWorkerContext,
-    # while get_worker_addresses method isn't currently implemented in RayExecutionWorkerContext.
-    try:
-        worker_addresses = ctx.get_worker_addresses()
-    except AttributeError:  # pragma: no cover
-        assert isinstance(ctx, RayExecutionWorkerContext)
-        return True
-    return isinstance(ctx, RayExecutionContext) or is_ray_address(worker_addresses[0])
-
-
 def restore_func(ctx: Context, op):
     if op.need_clean_up_func and ctx is not None:
         logger.debug(f"{op} need restoring func.")
-        if _is_on_ray(ctx):
+        if is_on_ray(ctx):
             import ray

             op.func = ray.get(op.func_key)
diff --git a/mars/deploy/oscar/tests/test_ray_dag_oscar.py b/mars/deploy/oscar/tests/test_ray_dag_oscar.py
index 41d28c8c6c..dc658527b4 100644
--- a/mars/deploy/oscar/tests/test_ray_dag_oscar.py
+++ b/mars/deploy/oscar/tests/test_ray_dag_oscar.py
@@ -52,3 +52,23 @@
 @require_ray
 async def test_execute_describe(ray_start_regular_shared2, create_cluster):
     await test_local.test_execute_describe(create_cluster)
+
+
+@require_ray
+@pytest.mark.parametrize(
+    "create_cluster",
+    [
+        {
+            "config": {
+                "task.task_preprocessor_cls": "mars.deploy.oscar.tests.test_clean_up_and_restore_func.RayBackendFuncTaskPreprocessor",
+                "subtask.subtask_processor_cls": "mars.deploy.oscar.tests.test_clean_up_and_restore_func.RayBackendFuncSubtaskProcessor",
+            }
+        }
+    ],
+    indirect=True,
+)
+@pytest.mark.asyncio
+async def test_ray_dag_oscar_clean_up_and_restore_func(
+    ray_start_regular_shared2, create_cluster
+):
+    await test_local.test_execute_apply_closure(create_cluster)
\ No newline at end of file
diff --git a/mars/utils.py b/mars/utils.py
index 3ce87e8523..a392f2575e 100644
--- a/mars/utils.py
+++ b/mars/utils.py
@@ -1701,6 +1701,30 @@ def is_ray_address(address: str) -> bool:
     return False


+# TODO: the clean_up_func, is_on_ray and restore_func functions may be
+# removed or refactored in the future to calculate the func size more
+# accurately and to address some serialization issues.
+def is_on_ray(ctx):
+    from .services.task.execution.ray.context import (
+        RayExecutionContext,
+        RayExecutionWorkerContext,
+    )
+    # There are three cases:
+    # a. mars backend
+    # b. ray backend (oscar), c. ray backend (dag)
+    # When a. or b. is selected, ctx is an instance of ThreadedServiceContext.
+    # The main difference between them is whether the worker address matches the ray scheme.
+    # To avoid duplicate checks, we only inspect the first worker address.
+    # When c. is selected, ctx is an instance of RayExecutionContext or RayExecutionWorkerContext;
+    # the get_worker_addresses method isn't currently implemented in RayExecutionWorkerContext.
+    try:
+        worker_addresses = ctx.get_worker_addresses()
+    except AttributeError:  # pragma: no cover
+        assert isinstance(ctx, RayExecutionWorkerContext)
+        return True
+    return isinstance(ctx, RayExecutionContext) or is_ray_address(worker_addresses[0])
+
+
 def cache_tileables(*tileables):
     from .core import ENTITY_TYPE
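
For review context, a minimal runnable sketch of the duck-typed contract the relocated helper relies on. The stub context classes and is_on_ray_sketch below are hypothetical, written only for this note; the simplified is_ray_address stand-in and the omission of the isinstance checks against the real Ray context classes are likewise assumptions — only the try/except fallback and the first-worker-address check mirror the patched is_on_ray:

    # Sketch only: hypothetical stand-ins illustrating the three cases the
    # in-code comment describes.

    def is_ray_address(address: str) -> bool:
        # Simplified stand-in for mars.utils.is_ray_address: treat a
        # "ray://" scheme as a Ray cluster address.
        return address.startswith("ray://")


    class LocalContext:
        # Hypothetical stand-in for ThreadedServiceContext on the mars backend.
        def get_worker_addresses(self):
            return ["127.0.0.1:17807"]


    class RayOscarContext:
        # Hypothetical stand-in for ThreadedServiceContext on the ray (oscar) backend.
        def get_worker_addresses(self):
            return ["ray://ray-cluster-1/0"]


    class RayDagWorkerContext:
        # Hypothetical stand-in for RayExecutionWorkerContext, which does
        # not implement get_worker_addresses.
        pass


    def is_on_ray_sketch(ctx) -> bool:
        # Mirrors the patched helper minus the isinstance short-circuits:
        # contexts without get_worker_addresses are assumed to be on Ray;
        # otherwise the first worker address decides.
        try:
            worker_addresses = ctx.get_worker_addresses()
        except AttributeError:
            return True
        return is_ray_address(worker_addresses[0])


    assert is_on_ray_sketch(LocalContext()) is False
    assert is_on_ray_sketch(RayOscarContext()) is True
    assert is_on_ray_sketch(RayDagWorkerContext()) is True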
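
Similarly, a sketch of the clean-up/restore round-trip that the helper gates. FakeOp and both function names are hypothetical, and the cloudpickle.loads branch of restore is an assumption, since this diff only shows the Ray branch of restore_func:

    # Sketch only: on Ray the closure goes through the object store
    # (ray.put / ray.get); elsewhere it is serialized in place.

    import cloudpickle


    class FakeOp:
        # Hypothetical stand-in for a Mars operand carrying a user closure.
        def __init__(self, func):
            self.func = func
            self.func_key = None


    def clean_up(op, on_ray: bool):
        if on_ray:
            import ray

            op.func_key = ray.put(op.func)  # store the closure in Ray's object store
            op.func = None
        else:
            op.func = cloudpickle.dumps(op.func)  # serialize in place


    def restore(op, on_ray: bool):
        if on_ray:
            import ray

            op.func = ray.get(op.func_key)
        else:
            op.func = cloudpickle.loads(op.func)  # assumed non-Ray restore path


    op = FakeOp(lambda x: x + 1)
    clean_up(op, on_ray=False)
    restore(op, on_ray=False)
    assert op.func(1) == 2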