diff --git a/modin/core/execution/ray/common/deferred_execution.py b/modin/core/execution/ray/common/deferred_execution.py
index 0ace6c27b10..db90130bafa 100644
--- a/modin/core/execution/ray/common/deferred_execution.py
+++ b/modin/core/execution/ray/common/deferred_execution.py
@@ -34,8 +34,9 @@
 from modin.core.execution.ray.common import MaterializationHook, RayWrapper
 from modin.logging import get_logger
+from modin.utils import _inherit_docstrings
 
-ObjectRefType = Union[ray.ObjectRef, ClientObjectRef, None]
+ObjectRefType = Union[ray.ObjectRef, ClientObjectRef]
 ObjectRefOrListType = Union[ObjectRefType, List[ObjectRefType]]
 ListOrTuple = (list, tuple)
 
@@ -68,16 +69,18 @@ class DeferredExecution:
 
     Attributes
     ----------
-    data : ObjectRefType or DeferredExecution
+    data : object
        The execution input.
     func : callable or ObjectRefType
        A function to be executed.
-    args : list or tuple
+    args : list or tuple, optional
        Additional positional arguments to be passed in `func`.
-    kwargs : dict
+    kwargs : dict, optional
        Additional keyword arguments to be passed in `func`.
-    num_returns : int
+    num_returns : int, default: 1
        The number of the return values.
+    flat_data : bool
+        True means that the data is neither a ``DeferredExecution`` nor a list.
     flat_args : bool
        True means that there are no lists or DeferredExecution objects in `args`.
        In this case, no arguments processing is performed and `args` is passed
@@ -88,26 +91,29 @@ class DeferredExecution:
 
     def __init__(
         self,
-        data: Union[
-            ObjectRefType,
-            "DeferredExecution",
-            List[Union[ObjectRefType, "DeferredExecution"]],
-        ],
+        data: Any,
         func: Union[Callable, ObjectRefType],
-        args: Union[List[Any], Tuple[Any]],
-        kwargs: Dict[str, Any],
+        args: Union[List[Any], Tuple[Any]] = None,
+        kwargs: Dict[str, Any] = None,
         num_returns=1,
     ):
-        if isinstance(data, DeferredExecution):
-            data.subscribe()
+        self.flat_data = self._flat_args((data,))
         self.data = data
         self.func = func
-        self.args = args
-        self.kwargs = kwargs
         self.num_returns = num_returns
-        self.flat_args = self._flat_args(args)
-        self.flat_kwargs = self._flat_args(kwargs.values())
         self.subscribers = 0
+        if args is not None:
+            self.args = args
+            self.flat_args = self._flat_args(args)
+        else:
+            self.args = ()
+            self.flat_args = True
+        if kwargs is not None:
+            self.kwargs = kwargs
+            self.flat_kwargs = self._flat_args(kwargs.values())
+        else:
+            self.kwargs = {}
+            self.flat_kwargs = True
 
     @classmethod
     def _flat_args(cls, args: Iterable):
@@ -134,7 +140,7 @@ def _flat_args(cls, args: Iterable):
 
     def exec(
         self,
-    ) -> Tuple[ObjectRefOrListType, Union["MetaList", List], Union[int, List[int]]]:
+    ) -> Tuple[ObjectRefOrListType, "MetaList", Union[int, List[int]]]:
         """
         Execute this task, if required.
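With `args` and `kwargs` now optional, building a lazy chain takes one line per step. A minimal sketch of the intended usage (the frame contents and lambdas here are made up for illustration, not part of this patch):

```python
import pandas
from modin.core.execution.ray.common import RayWrapper
from modin.core.execution.ray.common.deferred_execution import DeferredExecution

df_ref = RayWrapper.put(pandas.DataFrame({"a": range(10)}))

# Each task wraps the previous one; nothing is submitted to Ray yet.
head = DeferredExecution(df_ref, lambda df: df + 1)
tail = DeferredExecution(head, lambda df, n: df * n, args=[2])

# exec() submits the whole chain as a single remote call and returns the
# object ref plus the MetaList holding (length, width, ip) and its offset.
data_ref, meta, meta_offset = tail.exec()
result = RayWrapper.materialize(data_ref)
```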
@@ -150,11 +156,29 @@ def exec(
             return self.data, self.meta, self.meta_offset
 
         if (
-            not isinstance(self.data, DeferredExecution)
+            self.flat_data
             and self.flat_args
             and self.flat_kwargs
             and self.num_returns == 1
         ):
+            # self.data = RayWrapper.materialize(self.data)
+            # self.args = [
+            #     RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
+            #     for o in self.args
+            # ]
+            # self.kwargs = {
+            #     k: RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
+            #     for k, o in self.kwargs.items()
+            # }
+            # obj = _REMOTE_EXEC.exec_func(
+            #     RayWrapper.materialize(self.func), self.data, self.args, self.kwargs
+            # )
+            # result, length, width, ip = (
+            #     obj,
+            #     len(obj) if hasattr(obj, "__len__") else 0,
+            #     len(obj.columns) if hasattr(obj, "columns") else 0,
+            #     "",
+            # )
             result, length, width, ip = remote_exec_func.remote(
                 self.func, self.data, *self.args, **self.kwargs
             )
@@ -166,6 +190,14 @@ def exec(
         # it back. After the execution, the result is saved and the counter has no effect.
         self.subscribers += 2
         consumers, output = self._deconstruct()
+
+        # assert not any(isinstance(o, ListOrTuple) for o in output)
+        # tmp = [
+        #     RayWrapper.materialize(o) if isinstance(o, ray.ObjectRef) else o
+        #     for o in output
+        # ]
+        # list(_REMOTE_EXEC.construct(tmp))
+
         # The last result is the MetaList, so adding +1 here.
         num_returns = sum(c.num_returns for c in consumers) + 1
         results = self._remote_exec_chain(num_returns, *output)
@@ -173,12 +205,13 @@ def exec(
         meta_offset = 0
         results = iter(results)
         for de in consumers:
-            if de.num_returns == 1:
+            num_returns = de.num_returns
+            if num_returns == 1:
                 de._set_result(next(results), meta, meta_offset)
                 meta_offset += 2
             else:
                 res = list(islice(results, num_returns))
-                offsets = list(range(0, 2 * num_returns, 2))
+                offsets = list(range(meta_offset, meta_offset + 2 * num_returns, 2))
                 de._set_result(res, meta, offsets)
                 meta_offset += 2 * num_returns
         return self.data, self.meta, self.meta_offset
@@ -318,6 +351,7 @@ def _deconstruct_chain(
                 break
             elif not isinstance(data := de.data, DeferredExecution):
                 if isinstance(data, ListOrTuple):
+                    out_append(_Tag.LIST)
                     yield cls._deconstruct_list(
                         data, output, stack, result_consumers, out_append
                     )
@@ -394,7 +428,13 @@ def _deconstruct_list(
                 if out_pos := getattr(obj, "out_pos", None):
                     obj.unsubscribe()
                     if obj.has_result:
-                        out_append(obj.data)
+                        if isinstance(obj.data, ListOrTuple):
+                            out_append(_Tag.LIST)
+                            yield cls._deconstruct_list(
+                                obj.data, output, stack, result_consumers, out_append
+                            )
+                        else:
+                            out_append(obj.data)
                     else:
                         out_append(_Tag.REF)
                         out_append(out_pos)
@@ -432,13 +472,13 @@ def _remote_exec_chain(num_returns: int, *args: Tuple) -> List[Any]:
         list
             The execution results. The last element of this list is the ``MetaList``.
         """
-        # Prefer _remote_exec_single_chain(). It has fewer arguments and
-        # does not require the num_returns to be specified in options.
+        # Prefer _remote_exec_single_chain(). It does not require the num_returns
+        # to be specified in options.
         if num_returns == 2:
             return _remote_exec_single_chain.remote(*args)
         else:
             return _remote_exec_multi_chain.options(num_returns=num_returns).remote(
-                num_returns, *args
+                *args
             )
 
     def _set_result(
         self,
@@ -456,7 +496,7 @@ def _set_result(
         meta : MetaList
         meta_offset : int or list of int
         """
-        del self.func, self.args, self.kwargs, self.flat_args, self.flat_kwargs
+        del self.func, self.args, self.kwargs
         self.data = result
         self.meta = meta
         self.meta_offset = meta_offset
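The `_Tag.LIST` additions above ensure a list value always opens its frame with an explicit tag when a chain is deconstructed into a flat argument stream — including the case where an already-computed `DeferredExecution` result is itself a list. A simplified, self-contained model of that framing (symbolic tags only; the real encoding lives in `_deconstruct_list`/`construct_list`):

```python
from enum import Enum

class Tag(Enum):  # stand-in for the module's private _Tag
    LIST = 0
    END = 1

def encode(value, out):
    # Lists are flattened into an explicit LIST ... END frame so the
    # consumer can rebuild the nesting from a flat stream.
    if isinstance(value, (list, tuple)):
        out.append(Tag.LIST)
        for item in value:
            encode(item, out)
        out.append(Tag.END)
    else:
        out.append(value)

def decode(it, tok=None):
    tok = next(it) if tok is None else tok
    if tok is not Tag.LIST:
        return tok
    lst = []
    while (t := next(it)) is not Tag.END:
        lst.append(decode(it, t))
    return lst

out = []
encode([1, [2, 3], 4], out)
assert decode(iter(out)) == [1, [2, 3], 4]
```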
@@ -466,6 +506,78 @@ def __reduce__(self):
         raise NotImplementedError("DeferredExecution is not serializable!")
 
 
+ObjectRefOrDeType = Union[ObjectRefType, DeferredExecution]
+
+
+class DeferredGetItem(DeferredExecution):
+    """
+    Deferred execution task that returns an item at the specified index.
+
+    Parameters
+    ----------
+    data : ObjectRefOrDeType
+        The object to get the item from.
+    idx : int
+        The item index.
+    """
+
+    def __init__(self, data: ObjectRefOrDeType, idx: int):
+        super().__init__(data, self._remote_fn(), [idx])
+        self.index = idx
+
+    @_inherit_docstrings(DeferredExecution.exec)
+    def exec(self) -> Tuple[ObjectRefType, "MetaList", int]:
+        if self.has_result:
+            return self.data, self.meta, self.meta_offset
+
+        if not isinstance(self.data, DeferredExecution) or self.data.num_returns == 1:
+            return super().exec()
+
+        # If `data` is a `DeferredExecution` that returns multiple results,
+        # there is no need to execute `_remote_fn()`. We can just execute
+        # `data` and get the result by index.
+        self._data_exec()
+        return self.data, self.meta, self.meta_offset
+
+    @property
+    @_inherit_docstrings(DeferredExecution.has_result)
+    def has_result(self):
+        if super().has_result:
+            return True
+
+        if (
+            isinstance(self.data, DeferredExecution)
+            and self.data.has_result
+            and self.data.num_returns != 1
+        ):
+            self._data_exec()
+            return True
+
+        return False
+
+    def _data_exec(self):
+        """Execute the `data` task and get the result."""
+        obj, meta, offsets = self.data.exec()
+        self._set_result(obj[self.index], meta, offsets[self.index])
+
+    @classmethod
+    def _remote_fn(cls) -> ObjectRefType:
+        """
+        Return the remote function reference.
+
+        Returns
+        -------
+        ObjectRefType
+        """
+        if (fn := getattr(cls, "_GET_ITEM", None)) is None:
+
+            def get_item(obj, index):  # pragma: no cover
+                return obj[index]
+
+            cls._GET_ITEM = fn = RayWrapper.put(get_item)
+        return fn
+
+
 class MetaList:
     """
     Meta information, containing the result lengths and the worker address.
@@ -478,6 +590,10 @@ class MetaList:
     def __init__(self, obj: Union[ray.ObjectID, ClientObjectRef, List]):
         self._obj = obj
 
+    def materialize(self):
+        """Materialize the list, if required."""
+        self._obj = RayWrapper.materialize(self._obj)
+
     def __getitem__(self, index):
         """
         Get item at the specified index.
@@ -508,7 +624,7 @@ def __setitem__(self, index, value):
         obj[index] = value
 
 
-class MetaListHook(MaterializationHook):
+class MetaListHook(MaterializationHook, DeferredGetItem):
     """
     Used by MetaList.__getitem__() for lazy materialization and getting a single value from the list.
@@ -521,6 +637,7 @@ class MetaListHook(MaterializationHook, DeferredGetItem):
     """
 
     def __init__(self, meta: MetaList, idx: int):
+        super().__init__(meta._obj, idx)
         self.meta = meta
         self.idx = idx
@@ -605,7 +722,7 @@ def exec_func(fn: Callable, obj: Any, args: Tuple, kwargs: Dict) -> Any:
             raise err
 
     @classmethod
-    def construct(cls, num_returns: int, args: Tuple):  # pragma: no cover
+    def construct(cls, args: Tuple):  # pragma: no cover
         """
         Construct and execute the specified chain.
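`DeferredGetItem` shows up throughout the virtual partition code below. When its input is a multi-return task, `exec()` skips the remote `get_item` call entirely and just picks the right object ref. A usage sketch (`df_ref` is a hypothetical ref to a stored DataFrame, not part of the patch):

```python
from modin.core.execution.ray.common.deferred_execution import (
    DeferredExecution,
    DeferredGetItem,
)

def halves(df):  # runs in the worker
    mid = len(df) // 2
    return df.iloc[:mid], df.iloc[mid:]

split = DeferredExecution(df_ref, halves, num_returns=2)
top, bottom = DeferredGetItem(split, 0), DeferredGetItem(split, 1)

# Executing either item runs `split` once; both items are then served
# from its two object refs without a remote get_item call.
top_ref, meta, meta_offset = top.exec()
```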
@@ -615,7 +732,6 @@ def construct(cls, args: Tuple):  # pragma: no cover
 
         Parameters
         ----------
-        num_returns : int
         args : tuple
 
         Yields
@@ -687,7 +803,7 @@ def construct_chain(
 
         while chain:
             fn = pop()
-            if fn == tg_e:
+            if fn is tg_e:
                 lst.append(obj)
                 break
 
@@ -717,10 +833,10 @@ def construct_chain(
 
             itr = iter([obj] if num_returns == 1 else obj)
             for _ in range(num_returns):
-                obj = next(itr)
-                meta.append(len(obj) if hasattr(obj, "__len__") else 0)
-                meta.append(len(obj.columns) if hasattr(obj, "columns") else 0)
-                yield obj
+                o = next(itr)
+                meta.append(len(o) if hasattr(o, "__len__") else 0)
+                meta.append(len(o.columns) if hasattr(o, "columns") else 0)
+                yield o
 
     @classmethod
     def construct_list(
@@ -834,20 +950,18 @@ def _remote_exec_single_chain(
     -------
     Generator
     """
-    return remote_executor.construct(num_returns=2, args=args)
+    return remote_executor.construct(args=args)
 
 
 @ray.remote
 def _remote_exec_multi_chain(
-    num_returns: int, *args: Tuple, remote_executor=_REMOTE_EXEC
+    *args: Tuple, remote_executor=_REMOTE_EXEC
 ) -> Generator:  # pragma: no cover
     """
     Execute the deconstructed chain with multiple return values in a worker process.
 
     Parameters
     ----------
-    num_returns : int
-        The number of return values.
     *args : tuple
         A deconstructed chain to be executed.
     remote_executor : _RemoteExecutor, default: _REMOTE_EXEC
@@ -857,4 +971,4 @@ def _remote_exec_multi_chain(
     -------
     Generator
     """
-    return remote_executor.construct(num_returns, args)
+    return remote_executor.construct(args)
diff --git a/modin/core/execution/ray/common/engine_wrapper.py b/modin/core/execution/ray/common/engine_wrapper.py
index fc5f8a643d2..9050e168ab4 100644
--- a/modin/core/execution/ray/common/engine_wrapper.py
+++ b/modin/core/execution/ray/common/engine_wrapper.py
@@ -20,7 +20,7 @@
 import asyncio
 import os
 from types import FunctionType
-from typing import Sequence
+from typing import Iterable, Sequence
 
 import ray
 from ray.util.client.common import ClientObjectRef
@@ -214,7 +214,7 @@ def wait(cls, obj_ids, num_returns=None):
         num_returns : int, optional
         """
         if not isinstance(obj_ids, Sequence):
-            obj_ids = list(obj_ids)
+            obj_ids = list(obj_ids) if isinstance(obj_ids, Iterable) else [obj_ids]
 
         ids = set()
         for obj in obj_ids:
diff --git a/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition.py b/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition.py
index a4c35bf7e95..9c5e3a43c8f 100644
--- a/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition.py
+++ b/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/partition.py
@@ -148,7 +148,7 @@ def add_to_apply_calls(
     def drain_call_queue(self):
         data = self._data_ref
         if not isinstance(data, DeferredExecution):
-            return data
+            return
 
         log = get_logger()
         self._is_debug(log) and log.debug(
@@ -419,7 +419,7 @@ def eager_exec(self, func, *args, length=None, width=None, **kwargs):
 LazyExecution.subscribe(_configure_lazy_exec)
 
 
-class SlicerHook(MaterializationHook):
+class SlicerHook(MaterializationHook, DeferredExecution):
     """
     Used by mask() for the sliced length computation.
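The `RayWrapper.wait` change above normalizes three possible inputs: a sequence passes through unchanged, a non-sequence iterable (set, generator) is drained into a list, and a single bare ref is wrapped. A stdlib-only model of the same logic:

```python
from collections.abc import Iterable, Sequence

def normalize_obj_ids(obj_ids):
    # Sequences pass through unchanged; other iterables are drained into
    # a list; a bare scalar (e.g. a single ObjectRef) is wrapped.
    if isinstance(obj_ids, Sequence):
        return obj_ids
    return list(obj_ids) if isinstance(obj_ids, Iterable) else [obj_ids]

assert normalize_obj_ids([1, 2]) == [1, 2]
assert normalize_obj_ids({1}) == [1]
assert normalize_obj_ids(1) == [1]
```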
@@ -432,6 +432,7 @@ class SlicerHook(MaterializationHook, DeferredExecution):
     """
 
     def __init__(self, ref: ObjectIDType, slc: slice):
+        super().__init__(slc, compute_sliced_len, [ref])
         self.ref = ref
         self.slc = slc
diff --git a/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py b/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py
index f30b474305a..917ba989e65 100644
--- a/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py
+++ b/modin/core/execution/ray/implementations/pandas_on_ray/partitioning/virtual_partition.py
@@ -12,248 +12,443 @@
 # governing permissions and limitations under the License.
 
 """Module houses classes responsible for storing a virtual partition and applying a function to it."""
 
+import math
+from typing import Collection, List, Optional, Set, Union
+
 import pandas
 import ray
-from ray.util import get_node_ip_address
 
+from modin.config import MinPartitionSize
+from modin.core.dataframe.base.partitioning.axis_partition import (
+    BaseDataframeAxisPartition,
+)
 from modin.core.dataframe.pandas.partitioning.axis_partition import (
     PandasDataframeAxisPartition,
 )
 from modin.core.execution.ray.common import RayWrapper
+from modin.core.execution.ray.common.deferred_execution import (
+    DeferredExecution,
+    DeferredGetItem,
+    MetaList,
+    ObjectRefOrDeType,
+    ObjectRefType,
+)
 from modin.utils import _inherit_docstrings
 
 from .partition import PandasOnRayDataframePartition
 
 
-class PandasOnRayDataframeVirtualPartition(PandasDataframeAxisPartition):
+class PandasOnRayDataframeVirtualPartition(BaseDataframeAxisPartition):
     """
     The class implements the interface in ``PandasDataframeAxisPartition``.
 
     Parameters
     ----------
-    list_of_partitions : Union[list, PandasOnRayDataframePartition]
-        List of ``PandasOnRayDataframePartition`` and
-        ``PandasOnRayDataframeVirtualPartition`` objects, or a single
-        ``PandasOnRayDataframePartition``.
-    get_ip : bool, default: False
-        Whether to get node IP addresses to conforming partitions or not.
+    data : DeferredExecution or list of PandasOnRayDataframePartition
     full_axis : bool, default: True
         Whether or not the virtual partition encompasses the whole axis.
-    call_queue : list, optional
-        A list of tuples (callable, args, kwargs) that contains deferred calls.
     length : ray.ObjectRef or int, optional
         Length, or reference to length, of wrapped ``pandas.DataFrame``.
     width : ray.ObjectRef or int, optional
         Width, or reference to width, of wrapped ``pandas.DataFrame``.
+    num_splits : int, optional
+        The number of chunks to split the result into.
+    chunk_lengths : list of ints, optional
+        The chunk lengths.
""" - _PARTITIONS_METADATA_LEN = 3 # (length, width, ip) partition_type = PandasOnRayDataframePartition instance_type = ray.ObjectRef axis = None - # these variables are intentionally initialized at runtime (see #6023) - _DEPLOY_AXIS_FUNC = None - _DEPLOY_SPLIT_FUNC = None - _DRAIN_FUNC = None - - @classmethod - def _get_deploy_axis_func(cls): # noqa: GL08 - if cls._DEPLOY_AXIS_FUNC is None: - cls._DEPLOY_AXIS_FUNC = RayWrapper.put( - PandasDataframeAxisPartition.deploy_axis_func + def __init__( + self, + data: Union[ + DeferredExecution, + PandasOnRayDataframePartition, + List[PandasOnRayDataframePartition], + ], + full_axis: bool = True, + length: Union[int, ObjectRefType] = None, + width: Union[int, ObjectRefType] = None, + num_splits=None, + chunk_lengths=None, + ): + self.full_axis = full_axis + self._meta = MetaList([length, width, None]) + self._meta_offset = 0 + self._chunk_lengths_cache = chunk_lengths + + if isinstance(data, DeferredExecution): + self._set_data_ref(data) + self._num_splits = num_splits + self._list_of_block_partitions = None + return + + if not isinstance(data, Collection) or len(data) == 1: + if not isinstance(data, Collection): + data = [data] + self._set_data_ref(data[0]._data_ref) + self._num_splits = 1 + self._list_of_block_partitions = data + return + + self._num_splits = len(data) + self._list_of_block_partitions = data + refs = [part._data_ref for part in self._list_of_block_partitions] + + if ( + isinstance(refs[0], _DeferredGetChunk) + and (refs[0].index == 0) + and all(prev.is_next_chunk(next) for prev, next in zip(refs[:-1], refs[1:])) + ): + self._chunk_lengths_cache = ( + None + if any(chunk.length is None for chunk in refs) + else [chunk.length for chunk in refs] ) - return cls._DEPLOY_AXIS_FUNC - @classmethod - def _get_deploy_split_func(cls): # noqa: GL08 - if cls._DEPLOY_SPLIT_FUNC is None: - cls._DEPLOY_SPLIT_FUNC = RayWrapper.put( - PandasDataframeAxisPartition.deploy_splitting_func + split: _DeferredSplit = refs[0].split + if split.num_splits == refs[-1].index: + # All the partitions are the chunks of the same DataFrame. Concatenation of + # all these chunks will get a df identical to the original one. Thus, we + # don't need to concatenate but can get the original one instead. + self._set_data_ref(split.non_split) + return + + # TODO: We have a subset of the same frame here and can just get a single chunk + # from the original frame instead of concatenating all these chunks. + + self._set_data_ref(self._concat(refs)) + + def _set_data_ref( + self, data: Union[DeferredExecution, ObjectRefType] + ): # noqa: GL08 + if isinstance(data, DeferredExecution): + data.subscribe() + self._data_ref = data + + def __del__(self): + """Unsubscribe from DeferredExecution.""" + if isinstance(self._data_ref, DeferredExecution): + self._data_ref.unsubscribe() + + @_inherit_docstrings(BaseDataframeAxisPartition.apply) + def apply( + self, + func, + *args, + num_splits=None, + other_axis_partition=None, + maintain_partitioning=True, + lengths=None, + manual_partition=False, + **kwargs, + ) -> Union[List[PandasOnRayDataframePartition], PandasOnRayDataframePartition]: + if not manual_partition: + if not self.full_axis: + # If this is not a full axis partition, it already contains a subset of + # the full axis, so we shouldn't split the result further. 
+    @_inherit_docstrings(BaseDataframeAxisPartition.apply)
+    def apply(
+        self,
+        func,
+        *args,
+        num_splits=None,
+        other_axis_partition=None,
+        maintain_partitioning=True,
+        lengths=None,
+        manual_partition=False,
+        **kwargs,
+    ) -> Union[List[PandasOnRayDataframePartition], PandasOnRayDataframePartition]:
+        if not manual_partition:
+            if not self.full_axis:
+                # If this is not a full axis partition, it already contains a subset of
+                # the full axis, so we shouldn't split the result further.
+                num_splits = 1
+            elif num_splits is None:
+                num_splits = self._num_splits
+
+        if (
+            num_splits == 1
+            or not maintain_partitioning
+            or num_splits != self._num_splits
+        ):
+            lengths = None
+        elif lengths is None:
+            lengths = self._chunk_lengths
+
+        if other_axis_partition is not None:
+            if isinstance(other_axis_partition, Collection):
+                if len(other_axis_partition) == 1:
+                    other_part = other_axis_partition[0]._data_ref
+                else:
+                    concat_fn = (
+                        PandasOnRayDataframeColumnPartition
+                        if self.axis
+                        else PandasOnRayDataframeRowPartition
+                    )._concat
+                    other_part = concat_fn([p._data_ref for p in other_axis_partition])
+            else:
+                other_part = other_axis_partition._data_ref
+            args = [other_part] + list(args)
+
+        de = self._apply(func, args, kwargs)
+        if num_splits > 1:
+            de = self._split(de, num_splits, lengths)
+            if not lengths or len(lengths) != num_splits:
+                lengths = [None] * num_splits
+            result = [
+                PandasOnRayDataframePartition(_DeferredGetChunk(de, i, lengths[i]))
+                for i in range(num_splits)
+            ]
+        else:
+            result = [PandasOnRayDataframePartition(de)]
+        if self.full_axis or other_axis_partition is not None:
+            return result
+        else:
+            # If this is not a full axis partition, just take out the single split in the result.
+            return result[0]
+
+    @_inherit_docstrings(PandasDataframeAxisPartition.add_to_apply_calls)
+    def add_to_apply_calls(self, func, *args, length=None, width=None, **kwargs):
+        de = self._apply(func, args, kwargs)
+        return type(self)(
+            de, self.full_axis, length, width, self._num_splits, self._chunk_lengths
+        )
+
+    @_inherit_docstrings(PandasDataframeAxisPartition.split)
+    def split(
+        self, split_func, num_splits, f_args=None, f_kwargs=None, extract_metadata=False
+    ) -> List[PandasOnRayDataframePartition]:
+        de = DeferredExecution(
+            self._data_ref,
+            split_func,
+            args=f_args,
+            kwargs=f_kwargs,
+            num_returns=num_splits,
+        )
+
+        if num_splits > 1:
+            return [
+                PandasOnRayDataframePartition(DeferredGetItem(de, i))
+                for i in range(num_splits)
+            ]
+        return [PandasOnRayDataframePartition(de)]
+
+    @property
+    def _length_cache(self):  # noqa: GL08
+        return self._meta[self._meta_offset]
+
+    @_length_cache.setter
+    def _length_cache(self, value):  # noqa: GL08
+        self._meta[self._meta_offset] = value
+
+    def length(self, materialize=True):  # noqa: GL08
+        if self._length_cache is None:
+            self._calculate_lengths(materialize)
+        return self._length_cache
+
+    @property
+    def _width_cache(self):  # noqa: GL08
+        return self._meta[self._meta_offset + 1]
+
+    @_width_cache.setter
+    def _width_cache(self, value):  # noqa: GL08
+        self._meta[self._meta_offset + 1] = value
+
+    def width(self, materialize=True):  # noqa: GL08
+        if self._width_cache is None:
+            self._calculate_lengths(materialize)
+        return self._width_cache
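`length()`/`width()` above index into one shared `MetaList`: every result occupies a (length, width) pair, and `_meta_offset` selects the pair belonging to this partition. A toy stand-in for that layout (the real `MetaList` wraps a lazily materialized object ref):

```python
class Meta:  # simplified MetaList
    def __init__(self, values):
        self.values = values

    def length(self, offset):
        return self.values[offset]

    def width(self, offset):
        return self.values[offset + 1]

# Two results sharing one list, plus the worker IP at the end:
meta = Meta([5, 3, 7, 3, "10.0.0.5"])
assert (meta.length(0), meta.width(0)) == (5, 3)
assert (meta.length(2), meta.width(2)) == (7, 3)
```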
+    def _calculate_lengths(self, materialize=True):  # noqa: GL08
+        if self._list_of_block_partitions is not None:
+            from . import PandasOnRayDataframePartitionManager
+
+            lengths = [part.length(False) for part in self._list_of_block_partitions]
+            widths = [part.width(False) for part in self._list_of_block_partitions]
+            materialized = PandasOnRayDataframePartitionManager.materialize_futures(
+                lengths + widths
             )
-        return cls._DEPLOY_SPLIT_FUNC
+            self._meta[self._meta_offset] = sum(materialized[: len(lengths)])
+            self._meta[self._meta_offset + 1] = sum(materialized[len(lengths) :])
+        else:
+            self.force_materialization()
+        if materialize:
+            self._meta.materialize()
+
+    @_inherit_docstrings(PandasDataframeAxisPartition.drain_call_queue)
+    def drain_call_queue(self, num_splits=None):
+        if num_splits:
+            self._num_splits = num_splits
+
+    @_inherit_docstrings(PandasDataframeAxisPartition.force_materialization)
+    def force_materialization(self, get_ip=False):
+        self._data  # Trigger execution
+        self._num_splits = 1
+        self._chunk_lengths_cache = None
+        self._list_of_block_partitions = None
+        return self
+
+    @_inherit_docstrings(PandasDataframeAxisPartition.wait)
+    def wait(self):
+        """Wait for the computations on the object wrapped by the partition to complete."""
+        RayWrapper.wait(self._data)
 
-    @classmethod
-    def _get_drain_func(cls):  # noqa: GL08
-        if cls._DRAIN_FUNC is None:
-            cls._DRAIN_FUNC = RayWrapper.put(PandasDataframeAxisPartition.drain)
-        return cls._DRAIN_FUNC
+    @_inherit_docstrings(PandasDataframeAxisPartition.to_pandas)
+    def to_pandas(self):
+        return RayWrapper.materialize(self._data)
+
+    @_inherit_docstrings(PandasDataframeAxisPartition.to_numpy)
+    def to_numpy(self):
+        return self.to_pandas().to_numpy()
+
+    @_inherit_docstrings(PandasDataframeAxisPartition.mask)
+    def mask(self, row_indices, col_indices):
+        part = PandasOnRayDataframePartition(self._data_ref).mask(
+            row_indices, col_indices
+        )
+        return type(self)(part, False)
+
+    @property
+    @_inherit_docstrings(BaseDataframeAxisPartition.list_of_blocks)
+    def list_of_blocks(self):
+        return [part._data for part in self.list_of_block_partitions]
+
+    @property
+    @_inherit_docstrings(PandasDataframeAxisPartition.list_of_block_partitions)
+    def list_of_block_partitions(self) -> list:
+        if self._list_of_block_partitions is not None:
+            return self._list_of_block_partitions
+
+        data = self._data_ref
+        num_splits = self._num_splits
+        if num_splits > 1:
+            lengths = self._chunk_lengths
+            data = self._split(data, num_splits, lengths)
+            if not lengths or len(lengths) != num_splits:
+                lengths = [None] * num_splits
+            self._list_of_block_partitions = [
+                PandasOnRayDataframePartition(_DeferredGetChunk(data, i, lengths[i]))
+                for i in range(num_splits)
+            ]
+        else:
+            self._list_of_block_partitions = [PandasOnRayDataframePartition(data)]
+        return self._list_of_block_partitions
 
     @property
     def list_of_ips(self):
         """
-        Get the IPs holding the physical objects composing this partition.
+        Return the list of IP worker addresses.
 
         Returns
         -------
-        List
-            A list of IPs as ``ray.ObjectRef`` or str.
+        list of str
         """
-        # Defer draining call queue until we get the ip address
-        result = [None] * len(self.list_of_block_partitions)
-        for idx, partition in enumerate(self.list_of_block_partitions):
-            partition.drain_call_queue()
-            result[idx] = partition.ip(materialize=False)
-        return result
+        if (ip := self._meta[self._meta_offset + 2]) is not None:
+            return [ip]
+        if self._list_of_block_partitions is not None:
+            return [part.ip() for part in self._list_of_block_partitions]
+        return []
 
-    @classmethod
-    @_inherit_docstrings(PandasDataframeAxisPartition.deploy_splitting_func)
-    def deploy_splitting_func(
-        cls,
-        axis,
-        func,
-        f_args,
-        f_kwargs,
-        num_splits,
-        *partitions,
-        extract_metadata=False,
-    ):
-        return _deploy_ray_func.options(
-            num_returns=(
-                num_splits * (1 + cls._PARTITIONS_METADATA_LEN)
-                if extract_metadata
-                else num_splits
-            ),
-        ).remote(
-            cls._get_deploy_split_func(),
-            *f_args,
-            num_splits,
-            *partitions,
-            axis=axis,
-            f_to_deploy=func,
-            f_len_args=len(f_args),
-            f_kwargs=f_kwargs,
-            extract_metadata=extract_metadata,
-        )
+    @property
+    def _data(self):  # noqa: GL08
+        data = self._data_ref
+        if isinstance(data, DeferredExecution):
+            data, self._meta, self._meta_offset = data.exec()
+            self._data_ref = data
+        return data
+
+    @property
+    def _chunk_lengths(self):  # noqa: GL08
+        if (
+            self._chunk_lengths_cache is None
+            and self._list_of_block_partitions is not None
+        ):
+            attr = "length" if self.axis == 0 else "width"
+            self._chunk_lengths_cache = [
+                getattr(p, attr)(materialize=False)
+                for p in self._list_of_block_partitions
+            ]
+        return self._chunk_lengths_cache
 
     @classmethod
-    def deploy_axis_func(
-        cls,
-        axis,
-        func,
-        f_args,
-        f_kwargs,
-        num_splits,
-        maintain_partitioning,
-        *partitions,
-        lengths=None,
-        manual_partition=False,
-        max_retries=None,
-    ):
-        """
-        Deploy a function along a full axis.
-
-        Parameters
-        ----------
-        axis : {0, 1}
-            The axis to perform the function along.
-        func : callable
-            The function to perform.
-        f_args : list or tuple
-            Positional arguments to pass to ``func``.
-        f_kwargs : dict
-            Keyword arguments to pass to ``func``.
-        num_splits : int
-            The number of splits to return (see ``split_result_of_axis_func_pandas``).
-        maintain_partitioning : bool
-            If True, keep the old partitioning if possible.
-            If False, create a new partition layout.
-        *partitions : iterable
-            All partitions that make up the full axis (row or column).
-        lengths : list, optional
-            The list of lengths to shuffle the object.
-        manual_partition : bool, default: False
-            If True, partition the result with `lengths`.
-        max_retries : int, default: None
-            The max number of times to retry the func.
+    def _concat(cls, data):  # noqa: GL08
+        if (fn := getattr(cls, "_CONCAT_FN", None)) is None:
 
-        Returns
-        -------
-        list
-            A list of ``ray.ObjectRef``-s.
- """ - return _deploy_ray_func.options( - num_returns=(num_splits if lengths is None else len(lengths)) - * (1 + cls._PARTITIONS_METADATA_LEN), - **({"max_retries": max_retries} if max_retries is not None else {}), - ).remote( - cls._get_deploy_axis_func(), - *f_args, - num_splits, - maintain_partitioning, - *partitions, - axis=axis, - f_to_deploy=func, - f_len_args=len(f_args), - f_kwargs=f_kwargs, - manual_partition=manual_partition, - lengths=lengths, - return_generator=True, - ) + def concat(dfs, axis=cls.axis): # pragma: no cover + return pandas.concat(dfs, axis=axis, copy=False) + + cls._CONCAT_FN = fn = RayWrapper.put(concat) + return DeferredExecution(data, fn) + + def _apply(self, apply_fn, args, kwargs) -> DeferredExecution: # noqa: GL08 + return DeferredExecution(self._data_ref, apply_fn, args, kwargs) @classmethod - def deploy_func_between_two_axis_partitions( - cls, - axis, - func, - f_args, - f_kwargs, - num_splits, - len_of_left, - other_shape, - *partitions, + def _split( + cls, data: ObjectRefOrDeType, num_splits: int, lengths: Optional[List[int]] + ) -> "_DeferredSplit": # noqa: GL08 + if (fn := getattr(cls, "_SPLIT_FN", None)) is None: + + def split( + df: pandas.DataFrame, + num_splits: int, + min_chunk_len: int, + skip_chunks: Set[int], + *lengths: Optional[List[int]], + axis: int = cls.axis, + ): # pragma: no cover + if not lengths or (sum(lengths) != df.shape[axis]): + length = df.shape[axis] + chunk_len = max(math.ceil(length / num_splits), min_chunk_len) + lengths = [chunk_len] * num_splits + + result = [] + start = 0 + for i in range(num_splits): + if i in skip_chunks: + result.append(None) + start += lengths[i] + continue + + end = start + lengths[i] + chunk = df.iloc[start:end] if axis == 0 else df.iloc[:, start:end] + start = end + result.append(chunk) + if isinstance(chunk.axes[axis], pandas.MultiIndex): + chunk.set_axis( + chunk.axes[axis].remove_unused_levels(), + axis=axis, + copy=False, + ) + + return result + + cls._SPLIT_FN = fn = RayWrapper.put(split) + return _DeferredSplit(data, fn, num_splits, lengths) + + +class _DeferredSplit(DeferredExecution): # noqa: GL08 + def __init__( + self, + non_split: ObjectRefOrDeType, + func: ObjectRefType, + num_splits: int, + lengths: Optional[List[int]], ): - """ - Deploy a function along a full axis between two data sets. - - Parameters - ---------- - axis : {0, 1} - The axis to perform the function along. - func : callable - The function to perform. - f_args : list or tuple - Positional arguments to pass to ``func``. - f_kwargs : dict - Keyword arguments to pass to ``func``. - num_splits : int - The number of splits to return (see ``split_result_of_axis_func_pandas``). - len_of_left : int - The number of values in `partitions` that belong to the left data set. - other_shape : np.ndarray - The shape of right frame in terms of partitions, i.e. - (other_shape[i-1], other_shape[i]) will indicate slice to restore i-1 axis partition. - *partitions : iterable - All partitions that make up the full axis (row or column) for both data sets. - - Returns - ------- - list - A list of ``ray.ObjectRef``-s. 
- """ - return _deploy_ray_func.options( - num_returns=num_splits * (1 + cls._PARTITIONS_METADATA_LEN) - ).remote( - PandasDataframeAxisPartition.deploy_func_between_two_axis_partitions, - *f_args, - num_splits, - len_of_left, - other_shape, - *partitions, - axis=axis, - f_to_deploy=func, - f_len_args=len(f_args), - f_kwargs=f_kwargs, - return_generator=True, + self.non_split = non_split + self.num_splits = num_splits + self.skip_chunks = set() + args = [num_splits, MinPartitionSize.get(), self.skip_chunks] + if lengths and (len(lengths) == num_splits): + args.extend(lengths) + super().__init__(non_split, func, args, num_returns=num_splits) + + +class _DeferredGetChunk(DeferredGetItem): # noqa: GL08 + def __init__(self, split: _DeferredSplit, index: int, length: Optional[int] = None): + super().__init__(split, index) + self.split = split + self.length = length + + def __del__(self): + """Remove this chunk from _DeferredSplit if it's not executed yet.""" + if self.data is self.split: + self.split.skip_chunks.add(self.index) + + def is_next_chunk(self, other): # noqa: GL08 + return ( + isinstance(other, _DeferredGetChunk) + and (self.split is other.split) + and (other.index == self.index + 1) ) - def wait(self): - """Wait completing computations on the object wrapped by the partition.""" - self.drain_call_queue() - futures = self.list_of_blocks - RayWrapper.wait(futures) - @_inherit_docstrings(PandasOnRayDataframeVirtualPartition.__init__) class PandasOnRayDataframeColumnPartition(PandasOnRayDataframeVirtualPartition): @@ -263,75 +458,3 @@ class PandasOnRayDataframeColumnPartition(PandasOnRayDataframeVirtualPartition): @_inherit_docstrings(PandasOnRayDataframeVirtualPartition.__init__) class PandasOnRayDataframeRowPartition(PandasOnRayDataframeVirtualPartition): axis = 1 - - -@ray.remote -def _deploy_ray_func( - deployer, - *positional_args, - axis, - f_to_deploy, - f_len_args, - f_kwargs, - extract_metadata=True, - **kwargs, -): # pragma: no cover - """ - Execute a function on an axis partition in a worker process. - - This is ALWAYS called on either ``PandasDataframeAxisPartition.deploy_axis_func`` - or ``PandasDataframeAxisPartition.deploy_func_between_two_axis_partitions``, which both - serve to deploy another dataframe function on a Ray worker process. The provided `positional_args` - contains positional arguments for both: `deployer` and for `f_to_deploy`, the parameters can be separated - using the `f_len_args` value. The parameters are combined so they will be deserialized by Ray before the - kernel is executed (`f_kwargs` will never contain more Ray objects, and thus does not require deserialization). - - Parameters - ---------- - deployer : callable - A `PandasDataFrameAxisPartition.deploy_*` method that will call ``f_to_deploy``. - *positional_args : list - The first `f_len_args` elements in this list represent positional arguments - to pass to the `f_to_deploy`. The rest are positional arguments that will be - passed to `deployer`. - axis : {0, 1} - The axis to perform the function along. This argument is keyword only. - f_to_deploy : callable or RayObjectID - The function to deploy. This argument is keyword only. - f_len_args : int - Number of positional arguments to pass to ``f_to_deploy``. This argument is keyword only. - f_kwargs : dict - Keyword arguments to pass to ``f_to_deploy``. This argument is keyword only. - extract_metadata : bool, default: True - Whether to return metadata (length, width, ip) of the result. 
-
-
-@ray.remote
-def _deploy_ray_func(
-    deployer,
-    *positional_args,
-    axis,
-    f_to_deploy,
-    f_len_args,
-    f_kwargs,
-    extract_metadata=True,
-    **kwargs,
-):  # pragma: no cover
-    """
-    Execute a function on an axis partition in a worker process.
-
-    This is ALWAYS called on either ``PandasDataframeAxisPartition.deploy_axis_func``
-    or ``PandasDataframeAxisPartition.deploy_func_between_two_axis_partitions``, which both
-    serve to deploy another dataframe function on a Ray worker process. The provided `positional_args`
-    contains positional arguments for both: `deployer` and for `f_to_deploy`, the parameters can be separated
-    using the `f_len_args` value. The parameters are combined so they will be deserialized by Ray before the
-    kernel is executed (`f_kwargs` will never contain more Ray objects, and thus does not require deserialization).
-
-    Parameters
-    ----------
-    deployer : callable
-        A `PandasDataFrameAxisPartition.deploy_*` method that will call ``f_to_deploy``.
-    *positional_args : list
-        The first `f_len_args` elements in this list represent positional arguments
-        to pass to the `f_to_deploy`. The rest are positional arguments that will be
-        passed to `deployer`.
-    axis : {0, 1}
-        The axis to perform the function along. This argument is keyword only.
-    f_to_deploy : callable or RayObjectID
-        The function to deploy. This argument is keyword only.
-    f_len_args : int
-        Number of positional arguments to pass to ``f_to_deploy``. This argument is keyword only.
-    f_kwargs : dict
-        Keyword arguments to pass to ``f_to_deploy``. This argument is keyword only.
-    extract_metadata : bool, default: True
-        Whether to return metadata (length, width, ip) of the result. Passing `False` may relax
-        the load on object storage as the remote function would return 4 times fewer futures.
-        Passing `False` makes sense for temporary results where you know for sure that the
-        metadata will never be requested. This argument is keyword only.
-    **kwargs : dict
-        Keyword arguments to pass to ``deployer``.
-
-    Returns
-    -------
-    list : Union[tuple, list]
-        The result of the function call, and metadata for it.
-
-    Notes
-    -----
-    Ray functions are not detected by codecov (thus pragma: no cover).
-    """
-    f_args = positional_args[:f_len_args]
-    deploy_args = positional_args[f_len_args:]
-    result = deployer(axis, f_to_deploy, f_args, f_kwargs, *deploy_args, **kwargs)
-
-    if not extract_metadata:
-        for item in result:
-            yield item
-    else:
-        ip = get_node_ip_address()
-        for r in result:
-            if isinstance(r, pandas.DataFrame):
-                for item in [r, len(r), len(r.columns), ip]:
-                    yield item
-            else:
-                for item in [r, None, None, ip]:
-                    yield item
diff --git a/modin/experimental/batch/pipeline.py b/modin/experimental/batch/pipeline.py
index d7cd60cae47..b17fb1d7fd1 100644
--- a/modin/experimental/batch/pipeline.py
+++ b/modin/experimental/batch/pipeline.py
@@ -242,13 +242,14 @@ def _complete_nodes(self, list_of_nodes, partitions):
                 )
                 new_dfs[-1].drain_call_queue(num_splits=1)
 
-            def reducer(df):
+            def reducer(df, reduce_fn, *parts):
                 df_inputs = [df]
-                for df in new_dfs:
-                    df_inputs.append(df.to_pandas())
-                return node.reduce_fn(df_inputs)
+                df_inputs.extend(parts)
+                return reduce_fn(df_inputs)
 
-            partitions = [partitions[0].add_to_apply_calls(reducer)]
+            args = [node.reduce_fn]
+            args.extend([dfs._data_ref for dfs in new_dfs])
+            partitions = [partitions[0].add_to_apply_calls(reducer, *args)]
         elif node.repartition_after:
             if len(partitions) > 1:
                 ErrorMessage.not_implemented(
diff --git a/modin/pandas/test/dataframe/test_binary.py b/modin/pandas/test/dataframe/test_binary.py
index 48c0afd3318..8d52860004c 100644
--- a/modin/pandas/test/dataframe/test_binary.py
+++ b/modin/pandas/test/dataframe/test_binary.py
@@ -18,8 +18,8 @@
 import modin.pandas as pd
 from modin.config import Engine, NPartitions, StorageFormat
-from modin.core.dataframe.pandas.partitioning.axis_partition import (
-    PandasDataframeAxisPartition,
+from modin.core.dataframe.base.partitioning.axis_partition import (
+    BaseDataframeAxisPartition,
 )
 from modin.pandas.test.utils import (
     CustomIntegerForAddition,
@@ -228,7 +228,7 @@ def modin_df(is_virtual):
         # Modin should rebalance the partitions after the concat, producing virtual partitions.
         assert isinstance(
             result._query_compiler._modin_frame._partitions[0][0],
-            PandasDataframeAxisPartition,
+            BaseDataframeAxisPartition,
         )
         return result
diff --git a/modin/pandas/test/test_groupby.py b/modin/pandas/test/test_groupby.py
index 5113ed04e2f..73886758d01 100644
--- a/modin/pandas/test/test_groupby.py
+++ b/modin/pandas/test/test_groupby.py
@@ -28,8 +28,8 @@
     StorageFormat,
 )
 from modin.core.dataframe.algebra.default2pandas.groupby import GroupBy
-from modin.core.dataframe.pandas.partitioning.axis_partition import (
-    PandasDataframeAxisPartition,
+from modin.core.dataframe.base.partitioning.axis_partition import (
+    BaseDataframeAxisPartition,
 )
 from modin.pandas.io import from_pandas
 from modin.pandas.utils import is_scalar
@@ -264,6 +264,9 @@ def test_mixed_dtypes_groupby(as_index):
         ),
     )
     # FIXME: https://github.com/modin-project/modin/issues/7032
+    # Trigger execution of deferred operations. If not executed, eval_shift() below fails with
+    # `could not convert string to float: '\x94'`. Probably, this is also related to #7032.
+    modin_groupby.shift()
     eval_general(
         modin_groupby,
         pandas_groupby,
@@ -2626,7 +2629,7 @@ def test_groupby_with_virtual_partitions():
     # Check that the constructed Modin DataFrame has virtual partitions when
     assert issubclass(
         type(big_modin_df._query_compiler._modin_frame._partitions[0][0]),
-        PandasDataframeAxisPartition,
+        BaseDataframeAxisPartition,
     )
     eval_general(
         big_modin_df, big_pandas_df, lambda df: df.groupby(df.columns[0]).count()
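Since operations are now queued rather than executed eagerly, failures like the one worked around above can depend on when the queue is drained. When debugging similar issues, execution can be forced explicitly; a hypothetical helper using the same internal attributes as the tests above:

```python
def force_execution(modin_df):
    # Drain the deferred task queue of every partition so that later
    # operations start from plain, already-computed object refs.
    for row in modin_df._query_compiler._modin_frame._partitions:
        for part in row:
            part.drain_call_queue()
```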