Skip to content

Commit

Permalink
perf
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreyPavlenko committed Mar 20, 2024
1 parent ca80fd1 commit 8ce0b34
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 103 deletions.
111 changes: 36 additions & 75 deletions modin/core/execution/ray/common/deferred_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,24 +161,6 @@ def exec(
and self.flat_kwargs
and self.num_returns == 1
):
# self.data = RayWrapper.materialize(self.data)
# self.args = [
# RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
# for o in self.args
# ]
# self.kwargs = {
# k: RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
# for k, o in self.kwargs.items()
# }
# obj = _REMOTE_EXEC.exec_func(
# RayWrapper.materialize(self.func), self.data, self.args, self.kwargs
# )
# result, length, width, ip = (
# obj,
# len(obj) if hasattr(obj, "__len__") else 0,
# len(obj.columns) if hasattr(obj, "columns") else 0,
# "",
# )
result, length, width, ip = remote_exec_func.remote(
self.func, self.data, *self.args, **self.kwargs
)
Expand All @@ -191,13 +173,6 @@ def exec(
self.subscribers += 2
consumers, output = self._deconstruct()

# assert not any(isinstance(o, ListOrTuple) for o in output)
# tmp = [
# RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
# for o in output
# ]
# list(_REMOTE_EXEC.construct(tmp))

# The last result is the MetaList, so adding +1 here.
num_returns = sum(c.num_returns for c in consumers) + 1
results = self._remote_exec_chain(num_returns, *output)
Expand Down Expand Up @@ -336,7 +311,9 @@ def _deconstruct_chain(
out_extend = output.extend
while True:
de.unsubscribe()
if not de.has_result and (out_pos := getattr(de, "out_pos", None)):
if not (has_result := de.has_result) and (
out_pos := getattr(de, "out_pos", None)
):
out_append(_Tag.REF)
out_append(out_pos)
output[out_pos] = out_pos
Expand All @@ -357,7 +334,7 @@ def _deconstruct_chain(
)
else:
out_append(data)
if not de.has_result:
if not has_result:
stack.append(de)
break
else:
Expand Down Expand Up @@ -425,28 +402,24 @@ def _deconstruct_list(
"""
for obj in lst:
if isinstance(obj, DeferredExecution):
if out_pos := getattr(obj, "out_pos", None):
if obj.has_result:
obj = obj.data
elif out_pos := getattr(obj, "out_pos", None):
obj.unsubscribe()
if obj.has_result:
if isinstance(obj.data, ListOrTuple):
out_append(_Tag.LIST)
yield cls._deconstruct_list(
obj.data, output, stack, result_consumers, out_append
)
else:
out_append(obj.data)
else:
out_append(_Tag.REF)
out_append(out_pos)
output[out_pos] = out_pos
if obj.subscribers == 0:
output[out_pos + 1] = 0
result_consumers.remove(obj)
out_append(_Tag.REF)
out_append(out_pos)
output[out_pos] = out_pos
if obj.subscribers == 0:
output[out_pos + 1] = 0
result_consumers.remove(obj)
continue

Check warning on line 415 in modin/core/execution/ray/common/deferred_execution.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/common/deferred_execution.py#L409-L415

Added lines #L409 - L415 were not covered by tests
else:
out_append(_Tag.CHAIN)
yield cls._deconstruct_chain(obj, output, stack, result_consumers)
out_append(_Tag.END)
elif isinstance(obj, ListOrTuple):
continue

if isinstance(obj, ListOrTuple):
out_append(_Tag.LIST)
yield cls._deconstruct_list(
obj, output, stack, result_consumers, out_append
Expand Down Expand Up @@ -517,27 +490,13 @@ class DeferredGetItem(DeferredExecution):
----------
data : ObjectRefOrDeType
The object to get the item from.
idx : int
index : int
The item index.
"""

def __init__(self, data: ObjectRefOrDeType, idx: int):
super().__init__(data, self._remote_fn(), [idx])
self.index = idx

@_inherit_docstrings(DeferredExecution.exec)
def exec(self) -> Tuple[ObjectRefType, "MetaList", int]:
if self.has_result:
return self.data, self.meta, self.meta_offset

if not isinstance(self.data, DeferredExecution) or self.data.num_returns == 1:
return super().exec()

# If `data` is a `DeferredExecution`, that returns multiple results,
# it's not required to execute `_remote_fn()`. We can only execute
# `data` and get the result by index.
self._data_exec()
return self.data, self.meta, self.meta_offset
def __init__(self, data: ObjectRefOrDeType, index: int):
super().__init__(data, self._remote_fn(), [index])
self.index = index

@property
@_inherit_docstrings(DeferredExecution.has_result)
Expand All @@ -550,16 +509,18 @@ def has_result(self):
and self.data.has_result
and self.data.num_returns != 1
):
self._data_exec()
# If `data` is a `DeferredExecution`, that returns multiple results,
# it's not required to execute `_remote_fn()`. We can only execute
# `data` and get the result by index.
self._set_result(
self.data.data[self.index],
self.data.meta,
self.data.meta_offset[self.index],
)
return True

return False

def _data_exec(self):
"""Execute the `data` task and get the result."""
obj, meta, offsets = self.data.exec()
self._set_result(obj[self.index], meta, offsets[self.index])

@classmethod
def _remote_fn(cls) -> ObjectRefType:
"""
Expand Down Expand Up @@ -592,7 +553,8 @@ def __init__(self, obj: Union[ray.ObjectID, ClientObjectRef, List]):

def materialize(self):
"""Materialized the list, if required."""
self._obj = RayWrapper.materialize(self._obj)
if not isinstance(self._obj, list):
self._obj = RayWrapper.materialize(self._obj)

def __getitem__(self, index):
"""
Expand Down Expand Up @@ -632,14 +594,13 @@ class MetaListHook(MaterializationHook, DeferredGetItem):
----------
meta : MetaList
Non-materialized list to get the value from.
idx : int
index : int
The value index in the list.
"""

def __init__(self, meta: MetaList, idx: int):
super().__init__(meta._obj, idx)
def __init__(self, meta: MetaList, index: int):
super().__init__(meta._obj, index)
self.meta = meta
self.idx = idx

def pre_materialize(self):
"""
Expand All @@ -650,7 +611,7 @@ def pre_materialize(self):
object
"""
obj = self.meta._obj
return obj[self.idx] if isinstance(obj, list) else obj
return obj[self.index] if isinstance(obj, list) else obj

def post_materialize(self, materialized):
"""
Expand All @@ -665,7 +626,7 @@ def post_materialize(self, materialized):
object
"""
self.meta._obj = materialized
return materialized[self.idx]
return materialized[self.index]


class _Tag(Enum): # noqa: PR01
Expand Down
2 changes: 1 addition & 1 deletion modin/core/execution/ray/common/engine_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def materialize(cls, obj_id):
Parameters
----------
obj_id : ray.ObjectID
obj_id : ObjectRefTypes
Ray object identifier to get the value by.
Returns
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,16 @@
# governing permissions and limitations under the License.

"""Module houses class that implements ``GenericRayDataframePartitionManager`` using Ray."""
import math

import numpy as np
import pandas
from pandas.core.dtypes.common import is_numeric_dtype

from modin.config import AsyncReadMode
from modin.config import AsyncReadMode, MinPartitionSize
from modin.core.execution.modin_aqp import progress_bar_wrapper
from modin.core.execution.ray.common import RayWrapper
from modin.core.execution.ray.common.deferred_execution import DeferredExecution
from modin.core.execution.ray.generic.partitioning import (
GenericRayDataframePartitionManager,
)
Expand All @@ -29,6 +32,7 @@
from .virtual_partition import (
PandasOnRayDataframeColumnPartition,
PandasOnRayDataframeRowPartition,
PandasOnRayDataframeVirtualPartition,
)


Expand All @@ -42,6 +46,69 @@ class PandasOnRayDataframePartitionManager(GenericRayDataframePartitionManager):
_execution_wrapper = RayWrapper
materialize_futures = RayWrapper.materialize

@classmethod
@_inherit_docstrings(GenericRayDataframePartitionManager.get_indices)
def get_indices(cls, axis, partitions, index_func=None):
partitions = partitions.T if axis == 0 else partitions
if len(partitions) == 0:
return pandas.Index([]), []

partitions = [part for part in partitions[0]]
non_split, lengths, _ = (
PandasOnRayDataframeVirtualPartition.find_non_split_block(partitions)
)
if non_split is not None:
partitions = [non_split]
else:
partitions = [part._data for part in partitions]

if index_func is None:
attr_name = f"_GET_AXIS_{axis}"
if (fn := getattr(cls, attr_name, None)) is None:

def get_cols(*dfs, axis=axis):
return [df.axes[axis] for df in dfs]

setattr(cls, attr_name, get_cols)
fn = RayWrapper.put(get_cols)
data, args = partitions[0], partitions[1:]
else:
if (fn := getattr(cls, "_GET_AXIS_FN", None)) is None:

Check warning on line 76 in modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition_manager.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition_manager.py#L76

Added line #L76 was not covered by tests

def apply_index(fn, *dfs):
return [fn(df) for df in dfs]

Check warning on line 79 in modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition_manager.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition_manager.py#L78-L79

Added lines #L78 - L79 were not covered by tests

cls._GET_AXIS_FN = fn = RayWrapper.put(apply_index)
data, args = index_func, partitions

Check warning on line 82 in modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition_manager.py

View check run for this annotation

Codecov / codecov/patch

modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition_manager.py#L81-L82

Added lines #L81 - L82 were not covered by tests

de = DeferredExecution(data, fn, args, num_returns=len(partitions))
part_indices = de.exec()[0]

if non_split is not None:
materialized = RayWrapper.materialize([part_indices] + lengths)
idx = materialized[0][0]
lengths = materialized[1:]
idx_len = len(idx)

if any(length is None for length in lengths) or idx_len != sum(lengths):
count = len(lengths)
chunk_len = max(math.ceil(idx_len / count), MinPartitionSize.get())
lengths = [chunk_len] * count

part_indices = []
start = 0
for length in lengths:
end = start + length
part_indices.append(idx[start:end])
start = end
return idx, part_indices

part_indices = RayWrapper.materialize(part_indices)
indices = [idx for idx in part_indices if len(idx)]
if len(indices) == 0:
return part_indices[0], part_indices
return indices[0].append(indices[1:]), part_indices

@classmethod
def wait_partitions(cls, partitions):
"""
Expand Down
Loading

0 comments on commit 8ce0b34

Please sign in to comment.