Skip to content

Commit

Permalink
PERF-#5268: Call get on all partitions at once in to_pandas (#4776)
Browse files Browse the repository at this point in the history
Co-authored-by: Vasily Litvinov <[email protected]>
Co-authored-by: Dmitry Chigarev <[email protected]>
Signed-off-by: Myachev <[email protected]>
  • Loading branch information
3 people authored Nov 27, 2022
1 parent 8f6e642 commit 9534478
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 1 deletion.
19 changes: 18 additions & 1 deletion modin/core/dataframe/pandas/partitioning/partition_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,7 +641,24 @@ def to_pandas(cls, partitions):
pandas.DataFrame
A pandas DataFrame
"""
retrieved_objects = [[obj.to_pandas() for obj in part] for part in partitions]
retrieved_objects = cls.get_objects_from_partitions(partitions.flatten())
if all(
isinstance(obj, (pandas.DataFrame, pandas.Series))
for obj in retrieved_objects
):
height, width, *_ = tuple(partitions.shape) + (0,)
# restore 2d array
objs = iter(retrieved_objects)
retrieved_objects = [
[next(objs) for _ in range(width)] for __ in range(height)
]
else:
# Partitions do not always contain pandas objects, for example, hdk uses pyarrow tables.
# This implementation comes from the fact that calling `partition.get`
# function is not always equivalent to `partition.to_pandas`.
retrieved_objects = [
[obj.to_pandas() for obj in part] for part in partitions
]
if all(
isinstance(part, pandas.Series) for row in retrieved_objects for part in row
):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def get_objects_from_partitions(cls, partitions):
list
The objects wrapped by `partitions`.
"""
for idx, part in enumerate(partitions):
if hasattr(part, "force_materialization"):
partitions[idx] = part.force_materialization()
assert all(
[len(partition.list_of_blocks) == 1 for partition in partitions]
), "Implementation assumes that each partition contains a signle block."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,9 @@ def get_objects_from_partitions(cls, partitions):
list
The objects wrapped by `partitions`.
"""
for idx, part in enumerate(partitions):
if hasattr(part, "force_materialization"):
partitions[idx] = part.force_materialization()
assert all(
[len(partition.list_of_blocks) == 1 for partition in partitions]
), "Implementation assumes that each partition contains a signle block."
Expand Down

0 comments on commit 9534478

Please sign in to comment.