Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PERF-#7397: Avoid materializing index/columns in shape checks #7398

Merged
merged 5 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion modin/core/storage_formats/base/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import abc
import warnings
from functools import cached_property
from typing import TYPE_CHECKING, Hashable, List, Optional
from typing import TYPE_CHECKING, Hashable, List, Literal, Optional

import numpy as np
import pandas
Expand Down Expand Up @@ -4270,6 +4270,24 @@ def get_axis(self, axis):
"""
return self.index if axis == 0 else self.columns

def get_axis_len(self, axis: Literal[0, 1]) -> int:
"""
Return the length of the specified axis.

A query compiler may choose to override this method if it has a more efficient way
of computing the length of an axis without materializing it.

Parameters
----------
axis : {0, 1}
Axis to return labels on.

Returns
-------
int
"""
return len(self.get_axis(axis))

def take_2d_labels(
self,
index,
Expand Down
20 changes: 19 additions & 1 deletion modin/core/storage_formats/pandas/query_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import re
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Hashable, List, Optional
from typing import TYPE_CHECKING, Hashable, List, Literal, Optional

import numpy as np
import pandas
Expand Down Expand Up @@ -395,6 +395,24 @@ def from_dataframe(cls, df, data_cls):
index: pandas.Index = property(_get_axis(0), _set_axis(0))
columns: pandas.Index = property(_get_axis(1), _set_axis(1))

def get_axis_len(self, axis: Literal[0, 1]) -> int:
"""
Return the length of the specified axis.

Parameters
----------
axis : {0, 1}
Axis to return labels on.

Returns
-------
int
"""
if axis == 0:
return len(self._modin_frame)
else:
return sum(self._modin_frame.column_widths)

@property
def dtypes(self) -> pandas.Series:
return self._modin_frame.dtypes
Expand Down
12 changes: 7 additions & 5 deletions modin/pandas/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,7 +288,9 @@ def _build_repr_df(
A pandas dataset with `num_rows` or fewer rows and `num_cols` or fewer columns.
"""
# Fast track for empty dataframe.
if len(self.index) == 0 or (self._is_dataframe and len(self.columns) == 0):
if len(self) == 0 or (
self._is_dataframe and self._query_compiler.get_axis_len(1) == 0
):
return pandas.DataFrame(
index=self.index,
columns=self.columns if self._is_dataframe else None,
Expand Down Expand Up @@ -1004,7 +1006,7 @@ def error_raiser(msg, exception):
return result._query_compiler
return result
elif isinstance(func, dict):
if len(self.columns) != len(set(self.columns)):
if self._query_compiler.get_axis_len(1) != len(set(self.columns)):
warnings.warn(
"duplicate column names not supported with apply().",
FutureWarning,
Expand Down Expand Up @@ -2860,7 +2862,7 @@ def sample(
axis_length = len(axis_labels)
else:
# Getting rows requires indices instead of labels. RangeIndex provides this.
axis_labels = pandas.RangeIndex(len(self.index))
axis_labels = pandas.RangeIndex(len(self))
axis_length = len(axis_labels)
if weights is not None:
# Index of the weights Series should correspond to the index of the
Expand Down Expand Up @@ -3217,7 +3219,7 @@ def tail(self, n=5) -> Self: # noqa: PR01, RT01, D200
"""
if n != 0:
return self.iloc[-n:]
return self.iloc[len(self.index) :]
return self.iloc[len(self) :]

def take(self, indices, axis=0, **kwargs) -> Self: # noqa: PR01, RT01, D200
"""
Expand Down Expand Up @@ -4149,7 +4151,7 @@ def __len__(self) -> int:
-------
int
"""
return len(self.index)
return self._query_compiler.get_axis_len(0)

@_doc_binary_op(
operation="less than comparison",
Expand Down
53 changes: 30 additions & 23 deletions modin/pandas/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,13 +268,15 @@ def __repr__(self) -> str:
-------
str
"""
num_rows = pandas.get_option("display.max_rows") or len(self.index)
num_cols = pandas.get_option("display.max_columns") or len(self.columns)
num_rows = pandas.get_option("display.max_rows") or len(self)
num_cols = pandas.get_option(
"display.max_columns"
) or self._query_compiler.get_axis_len(1)
result = repr(self._build_repr_df(num_rows, num_cols))
if len(self.index) > num_rows or len(self.columns) > num_cols:
if len(self) > num_rows or self._query_compiler.get_axis_len(1) > num_cols:
# The split here is so that we don't repr pandas row lengths.
return result.rsplit("\n\n", 1)[0] + "\n\n[{0} rows x {1} columns]".format(
len(self.index), len(self.columns)
*self.shape
)
else:
return result
Expand All @@ -293,13 +295,11 @@ def _repr_html_(self) -> str: # pragma: no cover
# We use pandas _repr_html_ to get a string of the HTML representation
# of the dataframe.
result = self._build_repr_df(num_rows, num_cols)._repr_html_()
if len(self.index) > num_rows or len(self.columns) > num_cols:
if len(self) > num_rows or self._query_compiler.get_axis_len(1) > num_cols:
# We split so that we insert our correct dataframe dimensions.
return result.split("<p>")[
0
] + "<p>{0} rows x {1} columns</p>\n</div>".format(
len(self.index), len(self.columns)
)
] + "<p>{0} rows x {1} columns</p>\n</div>".format(*self.shape)
else:
return result

Expand Down Expand Up @@ -365,7 +365,7 @@ def empty(self) -> bool: # noqa: RT01, D200
"""
Indicate whether ``DataFrame`` is empty.
"""
return len(self.columns) == 0 or len(self.index) == 0
return self._query_compiler.get_axis_len(1) == 0 or len(self) == 0

@property
def axes(self) -> list[pandas.Index]: # noqa: RT01, D200
Expand All @@ -379,7 +379,7 @@ def shape(self) -> tuple[int, int]: # noqa: RT01, D200
"""
Return a tuple representing the dimensionality of the ``DataFrame``.
"""
return len(self.index), len(self.columns)
return len(self), self._query_compiler.get_axis_len(1)

def add_prefix(self, prefix, axis=None) -> DataFrame: # noqa: PR01, RT01, D200
"""
Expand Down Expand Up @@ -781,7 +781,9 @@ def dot(self, other) -> Union[DataFrame, Series]: # noqa: PR01, RT01, D200
"""
if isinstance(other, BasePandasDataset):
common = self.columns.union(other.index)
if len(common) > len(self.columns) or len(common) > len(other.index):
if len(common) > self._query_compiler.get_axis_len(1) or len(common) > len(
other
):
raise ValueError("Matrices are not aligned")

qc = other.reindex(index=common)._query_compiler
Expand Down Expand Up @@ -1084,7 +1086,7 @@ def insert(
+ f"{len(value.columns)} columns instead."
)
value = value.squeeze(axis=1)
if not self._query_compiler.lazy_row_count and len(self.index) == 0:
if not self._query_compiler.lazy_row_count and len(self) == 0:
if not hasattr(value, "index"):
try:
value = pandas.Series(value)
Expand All @@ -1099,7 +1101,7 @@ def insert(
new_query_compiler = self.__constructor__(
value, index=new_index, columns=new_columns
)._query_compiler
elif len(self.columns) == 0 and loc == 0:
elif self._query_compiler.get_axis_len(1) == 0 and loc == 0:
new_index = self.index
new_query_compiler = self.__constructor__(
data=value,
Expand All @@ -1110,18 +1112,19 @@ def insert(
if (
is_list_like(value)
and not isinstance(value, (pandas.Series, Series))
and len(value) != len(self.index)
and len(value) != len(self)
):
raise ValueError(
"Length of values ({}) does not match length of index ({})".format(
len(value), len(self.index)
len(value), len(self)
)
)
if allow_duplicates is not True and column in self.columns:
raise ValueError(f"cannot insert {column}, already exists")
if not -len(self.columns) <= loc <= len(self.columns):
columns_len = self._query_compiler.get_axis_len(1)
if not -columns_len <= loc <= columns_len:
raise IndexError(
f"index {loc} is out of bounds for axis 0 with size {len(self.columns)}"
f"index {loc} is out of bounds for axis 0 with size {columns_len}"
)
elif loc < 0:
raise ValueError("unbounded slice")
Expand Down Expand Up @@ -2074,9 +2077,11 @@ def squeeze(
Squeeze 1 dimensional axis objects into scalars.
"""
axis = self._get_axis_number(axis) if axis is not None else None
if axis is None and (len(self.columns) == 1 or len(self) == 1):
if axis is None and (
self._query_compiler.get_axis_len(1) == 1 or len(self) == 1
):
return Series(query_compiler=self._query_compiler).squeeze()
if axis == 1 and len(self.columns) == 1:
if axis == 1 and self._query_compiler.get_axis_len(1) == 1:
self._query_compiler._shape_hint = "column"
return Series(query_compiler=self._query_compiler)
if axis == 0 and len(self) == 1:
Expand Down Expand Up @@ -2671,7 +2676,7 @@ def __setitem__(self, key, value) -> None:
return self._setitem_slice(key, value)

if hashable(key) and key not in self.columns:
if isinstance(value, Series) and len(self.columns) == 0:
if isinstance(value, Series) and self._query_compiler.get_axis_len(1) == 0:
# Note: column information is lost when assigning a query compiler
prev_index = self.columns
self._query_compiler = value._query_compiler.copy()
Expand All @@ -2680,7 +2685,9 @@ def __setitem__(self, key, value) -> None:
self.columns = prev_index.insert(0, key)
return
# Do new column assignment after error checks and possible value modifications
self.insert(loc=len(self.columns), column=key, value=value)
self.insert(
loc=self._query_compiler.get_axis_len(1), column=key, value=value
)
return

if not hashable(key):
Expand Down Expand Up @@ -2756,7 +2763,7 @@ def __setitem__(self, key, value) -> None:

new_qc = self._query_compiler.insert_item(
axis=1,
loc=len(self.columns),
loc=self._query_compiler.get_axis_len(1),
value=value._query_compiler,
how="left",
)
Expand All @@ -2783,7 +2790,7 @@ def setitem_unhashable_key(df, value):
if not isinstance(value, (Series, Categorical, np.ndarray, list, range)):
value = list(value)

if not self._query_compiler.lazy_row_count and len(self.index) == 0:
if not self._query_compiler.lazy_row_count and len(self) == 0:
new_self = self.__constructor__({key: value}, columns=self.columns)
self._update_inplace(new_self._query_compiler)
else:
Expand Down
14 changes: 7 additions & 7 deletions modin/pandas/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,8 +445,8 @@ def __repr__(self) -> str:
name_str = "Name: {}, ".format(str(self.name))
else:
name_str = ""
if len(self.index) > num_rows:
len_str = "Length: {}, ".format(len(self.index))
if len(self) > num_rows:
len_str = "Length: {}, ".format(len(self))
else:
len_str = ""
dtype_str = "dtype: {}".format(
Expand Down Expand Up @@ -966,7 +966,7 @@ def dot(self, other) -> Union[Series, np.ndarray]: # noqa: PR01, RT01, D200
"""
if isinstance(other, BasePandasDataset):
common = self.index.union(other.index)
if len(common) > len(self.index) or len(common) > len(other.index):
if len(common) > len(self) or len(common) > len(other):
raise ValueError("Matrices are not aligned")

qc = other.reindex(index=common)._query_compiler
Expand Down Expand Up @@ -1761,7 +1761,7 @@ def reset_index(
name = 0 if self.name is None else self.name

if drop and level is None:
new_idx = pandas.RangeIndex(len(self.index))
new_idx = pandas.RangeIndex(len(self))
if inplace:
self.index = new_idx
else:
Expand Down Expand Up @@ -1989,7 +1989,7 @@ def squeeze(self, axis=None) -> Union[Series, Scalar]: # noqa: PR01, RT01, D200
if axis is not None:
# Validate `axis`
pandas.Series._get_axis_number(axis)
if len(self.index) == 1:
if len(self) == 1:
return self._reduce_dimension(self._query_compiler)
else:
return self.copy()
Expand Down Expand Up @@ -2307,7 +2307,7 @@ def empty(self) -> bool: # noqa: RT01, D200
"""
Indicate whether Series is empty.
"""
return len(self.index) == 0
return len(self) == 0

@property
def hasnans(self) -> bool: # noqa: RT01, D200
Expand Down Expand Up @@ -2648,7 +2648,7 @@ def _getitem(self, key) -> Union[Series, Scalar]:
if is_bool_indexer(key):
return self.__constructor__(
query_compiler=self._query_compiler.getitem_row_array(
pandas.RangeIndex(len(self.index))[key]
pandas.RangeIndex(len(self))[key]
)
)
# TODO: More efficiently handle `tuple` case for `Series.__getitem__`
Expand Down
Loading