From 5b25fb87f180855bc0435bcad3461acba1c5abf2 Mon Sep 17 00:00:00 2001 From: Stijn de Gooijer Date: Sun, 26 May 2024 12:35:57 +0200 Subject: [PATCH] refactor(python): Refactor `Series/DataFrame.__getitem__` logic (#16482) --- .../polars/_utils/construction/__init__.py | 2 - py-polars/polars/_utils/construction/other.py | 53 --- py-polars/polars/_utils/getitem.py | 432 ++++++++++++++++++ py-polars/polars/dataframe/frame.py | 247 ++-------- py-polars/polars/series/series.py | 114 +---- py-polars/polars/type_aliases.py | 31 +- py-polars/src/dataframe/general.rs | 29 +- py-polars/src/series/mod.rs | 17 +- py-polars/tests/unit/dataframe/test_df.py | 194 +------- .../tests/unit/dataframe/test_getitem.py | 332 +++++++++++++- py-polars/tests/unit/series/test_getitem.py | 73 +++ py-polars/tests/unit/test_errors.py | 10 +- 12 files changed, 964 insertions(+), 570 deletions(-) create mode 100644 py-polars/polars/_utils/getitem.py diff --git a/py-polars/polars/_utils/construction/__init__.py b/py-polars/polars/_utils/construction/__init__.py index 35f2232b2e0e..1b9a543bfb6d 100644 --- a/py-polars/polars/_utils/construction/__init__.py +++ b/py-polars/polars/_utils/construction/__init__.py @@ -10,7 +10,6 @@ ) from polars._utils.construction.other import ( coerce_arrow, - numpy_to_idxs, pandas_series_to_arrow, ) from polars._utils.construction.series import ( @@ -43,6 +42,5 @@ "series_to_pyseries", # other "coerce_arrow", - "numpy_to_idxs", "pandas_series_to_arrow", ] diff --git a/py-polars/polars/_utils/construction/other.py b/py-polars/polars/_utils/construction/other.py index c5a2a06bb7b5..4f362a772ca6 100644 --- a/py-polars/polars/_utils/construction/other.py +++ b/py-polars/polars/_utils/construction/other.py @@ -2,66 +2,13 @@ from typing import TYPE_CHECKING, Any -import polars._reexport as pl from polars._utils.construction.utils import get_first_non_none -from polars.datatypes import UInt32 -from polars.dependencies import numpy as np from polars.dependencies import pyarrow as pa -from polars.meta import get_index_type if TYPE_CHECKING: - from polars import Series from polars.dependencies import pandas as pd -def numpy_to_idxs(idxs: np.ndarray[Any, Any], size: int) -> Series: - # Unsigned or signed Numpy array (ordered from fastest to slowest). - # - np.uint32 (polars) or np.uint64 (polars_u64_idx) numpy array - # indexes. - # - Other unsigned numpy array indexes are converted to pl.UInt32 - # (polars) or pl.UInt64 (polars_u64_idx). - # - Signed numpy array indexes are converted pl.UInt32 (polars) or - # pl.UInt64 (polars_u64_idx) after negative indexes are converted - # to absolute indexes. - if idxs.ndim != 1: - msg = "only 1D numpy array is supported as index" - raise ValueError(msg) - - idx_type = get_index_type() - - if len(idxs) == 0: - return pl.Series("", [], dtype=idx_type) - - # Numpy array with signed or unsigned integers. - if idxs.dtype.kind not in ("i", "u"): - msg = "unsupported idxs datatype" - raise NotImplementedError(msg) - - if idx_type == UInt32: - if idxs.dtype in {np.int64, np.uint64} and idxs.max() >= 2**32: - msg = "index positions should be smaller than 2^32" - raise ValueError(msg) - if idxs.dtype == np.int64 and idxs.min() < -(2**32): - msg = "index positions should be bigger than -2^32 + 1" - raise ValueError(msg) - - if idxs.dtype.kind == "i" and idxs.min() < 0: - if idx_type == UInt32: - if idxs.dtype in (np.int8, np.int16): - idxs = idxs.astype(np.int32) - else: - if idxs.dtype in (np.int8, np.int16, np.int32): - idxs = idxs.astype(np.int64) - - # Update negative indexes to absolute indexes. - idxs = np.where(idxs < 0, size + idxs, idxs) - - # numpy conversion is much faster - idxs = idxs.astype(np.uint32) if idx_type == UInt32 else idxs.astype(np.uint64) - - return pl.Series("", idxs, dtype=idx_type) - - def pandas_series_to_arrow( values: pd.Series[Any] | pd.Index[Any], *, diff --git a/py-polars/polars/_utils/getitem.py b/py-polars/polars/_utils/getitem.py new file mode 100644 index 000000000000..84f80a10456e --- /dev/null +++ b/py-polars/polars/_utils/getitem.py @@ -0,0 +1,432 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Iterable, NoReturn, Sequence, overload + +import polars._reexport as pl +import polars.functions as F +from polars._utils.constants import U32_MAX +from polars._utils.various import range_to_slice +from polars.datatypes.classes import ( + Boolean, + Int8, + Int16, + Int32, + Int64, + String, + UInt32, + UInt64, +) +from polars.dependencies import _check_for_numpy +from polars.dependencies import numpy as np +from polars.meta.index_type import get_index_type +from polars.slice import PolarsSlice + +if TYPE_CHECKING: + from polars import DataFrame, Series + from polars.type_aliases import ( + MultiColSelector, + MultiIndexSelector, + SingleColSelector, + SingleIndexSelector, + ) + +__all__ = [ + "get_df_item_by_key", + "get_series_item_by_key", +] + + +@overload +def get_series_item_by_key(s: Series, key: SingleIndexSelector) -> Any: ... + + +@overload +def get_series_item_by_key(s: Series, key: MultiIndexSelector) -> Series: ... + + +def get_series_item_by_key( + s: Series, key: SingleIndexSelector | MultiIndexSelector +) -> Any | Series: + """Select one or more elements from the Series.""" + if isinstance(key, int): + return s._s.get_index_signed(key) + + elif isinstance(key, slice): + return _select_elements_by_slice(s, key) + + elif isinstance(key, range): + key = range_to_slice(key) + return _select_elements_by_slice(s, key) + + elif isinstance(key, Sequence): + if not key: + return s.clear() + if isinstance(key[0], bool): + _raise_on_boolean_mask() + indices = pl.Series("", key, dtype=Int64) + indices = _convert_series_to_indices(indices, s.len()) + return _select_elements_by_index(s, indices) + + elif isinstance(key, pl.Series): + indices = _convert_series_to_indices(key, s.len()) + return _select_elements_by_index(s, indices) + + elif _check_for_numpy(key) and isinstance(key, np.ndarray): + indices = _convert_np_ndarray_to_indices(key, s.len()) + return _select_elements_by_index(s, indices) + + msg = f"cannot select elements using key of type {type(key).__name__!r}: {key!r}" + raise TypeError(msg) + + +def _select_elements_by_slice(s: Series, key: slice) -> Series: + return PolarsSlice(s).apply(key) # type: ignore[return-value] + + +def _select_elements_by_index(s: Series, key: Series) -> Series: + return s._from_pyseries(s._s.gather_with_series(key._s)) + + +# `str` overlaps with `Sequence[str]` +# We can ignore this but we must keep this overload ordering +@overload +def get_df_item_by_key( + df: DataFrame, key: tuple[SingleIndexSelector, SingleColSelector] +) -> Any: ... + + +@overload +def get_df_item_by_key( # type: ignore[overload-overlap] + df: DataFrame, key: str | tuple[MultiIndexSelector, SingleColSelector] +) -> Series: ... + + +@overload +def get_df_item_by_key( + df: DataFrame, + key: ( + SingleIndexSelector + | MultiIndexSelector + | MultiColSelector + | tuple[SingleIndexSelector, MultiColSelector] + | tuple[MultiIndexSelector, MultiColSelector] + ), +) -> DataFrame: ... + + +def get_df_item_by_key( + df: DataFrame, + key: ( + SingleIndexSelector + | SingleColSelector + | MultiColSelector + | MultiIndexSelector + | tuple[SingleIndexSelector, SingleColSelector] + | tuple[SingleIndexSelector, MultiColSelector] + | tuple[MultiIndexSelector, SingleColSelector] + | tuple[MultiIndexSelector, MultiColSelector] + ), +) -> DataFrame | Series | Any: + """Get part of the DataFrame as a new DataFrame, Series, or scalar.""" + # Two inputs, e.g. df[1, 2:5] + if isinstance(key, tuple) and len(key) == 2: + row_key, col_key = key + selection = _select_columns(df, col_key) + + if selection.is_empty(): + return selection + elif isinstance(selection, pl.Series): + return get_series_item_by_key(selection, row_key) # type: ignore[arg-type] + else: + return _select_rows(selection, row_key) # type: ignore[arg-type] + + # Single input, e.g. df[1] + elif isinstance(key, str): + # This case is required because empty strings are otherwise treated + # as an empty Sequence in `_select_rows` + return df.get_column(key) + elif isinstance(key, Sequence) and len(key) == 0: + # df[[]] + # TODO: This removes all columns, but it should remove all rows. + # https://github.com/pola-rs/polars/issues/4924 + return df.__class__() + try: + return _select_rows(df, key) # type: ignore[arg-type] + except TypeError: + return _select_columns(df, key) # type: ignore[arg-type] + + +# `str` overlaps with `Sequence[str]` +# We can ignore this but we must keep this overload ordering +@overload +def _select_columns(df: DataFrame, key: SingleColSelector) -> Series: ... # type: ignore[overload-overlap] + + +@overload +def _select_columns(df: DataFrame, key: MultiColSelector) -> DataFrame: ... + + +def _select_columns( + df: DataFrame, key: SingleColSelector | MultiColSelector +) -> DataFrame | Series: + """Select one or more columns from the DataFrame.""" + if isinstance(key, int): + return df.to_series(key) + + elif isinstance(key, str): + return df.get_column(key) + + if isinstance(key, slice): + start = key.start + stop = key.stop + if isinstance(start, str): + start = df.get_column_index(start) + if isinstance(stop, str): + stop = df.get_column_index(stop) + 1 + int_slice = slice(start, stop, key.step) + rng = range(df.width)[int_slice] + return _select_columns_by_index(df, rng) + + elif isinstance(key, range): + return _select_columns_by_index(df, key) + + elif isinstance(key, Sequence): + if not key: + return df.__class__() + first = key[0] + if isinstance(first, bool): + return _select_columns_by_mask(df, key) # type: ignore[arg-type] + elif isinstance(first, int): + return _select_columns_by_index(df, key) # type: ignore[arg-type] + elif isinstance(first, str): + return _select_columns_by_name(df, key) # type: ignore[arg-type] + else: + msg = f"cannot select columns using Sequence with elements of type {type(first).__name__!r}" + raise TypeError(msg) + + elif isinstance(key, pl.Series): + if key.is_empty(): + return df.__class__() + dtype = key.dtype + if dtype == String: + return _select_columns_by_name(df, key) + elif dtype.is_integer(): + return _select_columns_by_index(df, key) + elif dtype == Boolean: + return _select_columns_by_mask(df, key) + else: + msg = f"cannot select columns using Series of type {dtype}" + raise TypeError(msg) + + elif _check_for_numpy(key) and isinstance(key, np.ndarray): + if key.ndim != 1: + msg = "multi-dimensional NumPy arrays not supported as index" + raise TypeError(msg) + + if len(key) == 0: + return df.__class__() + + dtype_kind = key.dtype.kind + if dtype_kind in ("i", "u"): + return _select_columns_by_index(df, key) + elif dtype_kind == "b": + return _select_columns_by_mask(df, key) + elif isinstance(key[0], str): + return _select_columns_by_name(df, key) + else: + msg = f"cannot select columns using NumPy array of type {key.dtype}" + raise TypeError(msg) + + msg = f"cannot select columns using key of type {type(key).__name__!r}: {key!r}" + raise TypeError(msg) + + +def _select_columns_by_index(df: DataFrame, key: Iterable[int]) -> DataFrame: + series = [df.to_series(i) for i in key] + return df.__class__(series) + + +def _select_columns_by_name(df: DataFrame, key: Iterable[str]) -> DataFrame: + return df._from_pydf(df._df.select(key)) + + +def _select_columns_by_mask( + df: DataFrame, key: Sequence[bool] | Series | np.ndarray[Any, Any] +) -> DataFrame: + if len(key) != df.width: + msg = f"expected {df.width} values when selecting columns by boolean mask, got {len(key)}" + raise ValueError(msg) + + indices = (i for i, val in enumerate(key) if val) + return _select_columns_by_index(df, indices) + + +@overload +def _select_rows(df: DataFrame, key: SingleIndexSelector) -> Series: ... + + +@overload +def _select_rows(df: DataFrame, key: MultiIndexSelector) -> DataFrame: ... + + +def _select_rows( + df: DataFrame, key: SingleIndexSelector | MultiIndexSelector +) -> DataFrame | Series: + """Select one or more rows from the DataFrame.""" + if isinstance(key, int): + return df.slice(key, 1) + + if isinstance(key, slice): + return _select_rows_by_slice(df, key) + + elif isinstance(key, range): + key = range_to_slice(key) + return _select_rows_by_slice(df, key) + + elif isinstance(key, Sequence): + if not key: + return df.clear() + if isinstance(key[0], bool): + _raise_on_boolean_mask() + s = pl.Series("", key, dtype=Int64) + indices = _convert_series_to_indices(s, df.height) + return _select_rows_by_index(df, indices) + + elif isinstance(key, pl.Series): + indices = _convert_series_to_indices(key, df.height) + return _select_rows_by_index(df, indices) + + elif _check_for_numpy(key) and isinstance(key, np.ndarray): + indices = _convert_np_ndarray_to_indices(key, df.height) + return _select_rows_by_index(df, indices) + + else: + msg = f"cannot select rows using key of type {type(key).__name__!r}: {key!r}" + raise TypeError(msg) + + +def _select_rows_by_slice(df: DataFrame, key: slice) -> DataFrame: + return PolarsSlice(df).apply(key) # type: ignore[return-value] + + +def _select_rows_by_index(df: DataFrame, key: Series) -> DataFrame: + return df._from_pydf(df._df.gather_with_series(key._s)) + + +# UTILS + + +def _convert_series_to_indices(s: Series, size: int) -> Series: + """Convert a Series to indices, taking into account negative values.""" + # Unsigned or signed Series (ordered from fastest to slowest). + # - pl.UInt32 (polars) or pl.UInt64 (polars_u64_idx) Series indexes. + # - Other unsigned Series indexes are converted to pl.UInt32 (polars) + # or pl.UInt64 (polars_u64_idx). + # - Signed Series indexes are converted pl.UInt32 (polars) or + # pl.UInt64 (polars_u64_idx) after negative indexes are converted + # to absolute indexes. + + # pl.UInt32 (polars) or pl.UInt64 (polars_u64_idx). + idx_type = get_index_type() + + if s.dtype == idx_type: + return s + + if not s.dtype.is_integer(): + if s.dtype == Boolean: + _raise_on_boolean_mask() + else: + msg = f"cannot treat Series of type {s.dtype} as indices" + raise TypeError(msg) + + if s.len() == 0: + return pl.Series(s.name, [], dtype=idx_type) + + if idx_type == UInt32: + if s.dtype in {Int64, UInt64} and s.max() >= U32_MAX: # type: ignore[operator] + msg = "index positions should be smaller than 2^32" + raise ValueError(msg) + if s.dtype == Int64 and s.min() < -U32_MAX: # type: ignore[operator] + msg = "index positions should be greater than or equal to -2^32" + raise ValueError(msg) + + if s.dtype.is_signed_integer(): + if s.min() < 0: # type: ignore[operator] + if idx_type == UInt32: + idxs = s.cast(Int32) if s.dtype in {Int8, Int16} else s + else: + idxs = s.cast(Int64) if s.dtype in {Int8, Int16, Int32} else s + + # Update negative indexes to absolute indexes. + return ( + idxs.to_frame() + .select( + F.when(F.col(idxs.name) < 0) + .then(size + F.col(idxs.name)) + .otherwise(F.col(idxs.name)) + .cast(idx_type) + ) + .to_series(0) + ) + + return s.cast(idx_type) + + +def _convert_np_ndarray_to_indices(arr: np.ndarray[Any, Any], size: int) -> Series: + """Convert a NumPy ndarray to indices, taking into account negative values.""" + # Unsigned or signed Numpy array (ordered from fastest to slowest). + # - np.uint32 (polars) or np.uint64 (polars_u64_idx) numpy array + # indexes. + # - Other unsigned numpy array indexes are converted to pl.UInt32 + # (polars) or pl.UInt64 (polars_u64_idx). + # - Signed numpy array indexes are converted pl.UInt32 (polars) or + # pl.UInt64 (polars_u64_idx) after negative indexes are converted + # to absolute indexes. + if arr.ndim != 1: + msg = "only 1D NumPy arrays can be treated as indices" + raise TypeError(msg) + + idx_type = get_index_type() + + if len(arr) == 0: + return pl.Series("", [], dtype=idx_type) + + # Numpy array with signed or unsigned integers. + if arr.dtype.kind not in ("i", "u"): + if arr.dtype.kind == "b": + _raise_on_boolean_mask() + else: + msg = f"cannot treat NumPy array of type {arr.dtype} as indices" + raise TypeError(msg) + + if idx_type == UInt32: + if arr.dtype in {np.int64, np.uint64} and arr.max() >= U32_MAX: + msg = "index positions should be smaller than 2^32" + raise ValueError(msg) + if arr.dtype == np.int64 and arr.min() < -U32_MAX: + msg = "index positions should be greater than or equal to -2^32" + raise ValueError(msg) + + if arr.dtype.kind == "i" and arr.min() < 0: + if idx_type == UInt32: + if arr.dtype in (np.int8, np.int16): + arr = arr.astype(np.int32) + else: + if arr.dtype in (np.int8, np.int16, np.int32): + arr = arr.astype(np.int64) + + # Update negative indexes to absolute indexes. + arr = np.where(arr < 0, size + arr, arr) + + # numpy conversion is much faster + arr = arr.astype(np.uint32) if idx_type == UInt32 else arr.astype(np.uint64) + + return pl.Series("", arr, dtype=idx_type) + + +def _raise_on_boolean_mask() -> NoReturn: + msg = ( + "selecting rows by passing a boolean mask to `__getitem__` is not supported" + "\n\nHint: Use the `filter` method instead." + ) + raise TypeError(msg) diff --git a/py-polars/polars/dataframe/frame.py b/py-polars/polars/dataframe/frame.py index 0d40bba318f6..988015a05627 100644 --- a/py-polars/polars/dataframe/frame.py +++ b/py-polars/polars/dataframe/frame.py @@ -24,7 +24,6 @@ NoReturn, Sequence, TypeVar, - Union, cast, get_args, overload, @@ -37,7 +36,6 @@ dataframe_to_pydf, dict_to_pydf, iterable_to_pydf, - numpy_to_idxs, numpy_to_pydf, pandas_to_pydf, sequence_to_pydf, @@ -53,15 +51,13 @@ deprecate_saturating, issue_deprecation_warning, ) +from polars._utils.getitem import get_df_item_by_key from polars._utils.parse_expr_input import parse_as_expression from polars._utils.unstable import issue_unstable_warning, unstable from polars._utils.various import ( is_bool_sequence, - is_int_sequence, - is_str_sequence, normalize_filepath, parse_version, - range_to_slice, scale_bytes, warn_null_comparison, ) @@ -103,7 +99,6 @@ ) from polars.functions import col, lit from polars.selectors import _expand_selector_dicts, _expand_selectors -from polars.slice import PolarsSlice from polars.type_aliases import DbWriteMode, JaxExportType, TorchExportType with contextlib.suppress(ImportError): # Module not available when building docs @@ -149,6 +144,8 @@ JoinStrategy, JoinValidation, Label, + MultiColSelector, + MultiIndexSelector, NullStrategy, OneOrMoreDataTypes, Orientation, @@ -160,6 +157,8 @@ SchemaDefinition, SchemaDict, SelectorType, + SingleColSelector, + SingleIndexSelector, SizeUnit, StartBy, UniqueKeepStrategy, @@ -167,25 +166,15 @@ ) if sys.version_info >= (3, 10): - from typing import Concatenate, ParamSpec, TypeAlias + from typing import Concatenate, ParamSpec else: - from typing_extensions import Concatenate, ParamSpec, TypeAlias + from typing_extensions import Concatenate, ParamSpec if sys.version_info >= (3, 11): from typing import Self else: from typing_extensions import Self - # these aliases are used to annotate DataFrame.__getitem__() - # MultiRowSelector indexes into the vertical axis and - # MultiColSelector indexes into the horizontal axis - # NOTE: wrapping these as strings is necessary for Python <3.10 - - MultiRowSelector: TypeAlias = Union[slice, range, "list[int]", "Series"] - MultiColSelector: TypeAlias = Union[ - slice, range, "list[int]", "list[str]", "list[bool]", "Series" - ] - T = TypeVar("T") P = ParamSpec("P") @@ -1007,199 +996,45 @@ def __iter__(self) -> Iterator[Series]: def __reversed__(self) -> Iterator[Series]: return reversed(self.get_columns()) - def _pos_idx(self, idx: int, dim: int) -> int: - if idx >= 0: - return idx - else: - return self.shape[dim] + idx - - def _take_with_series(self, s: Series) -> DataFrame: - return self._from_pydf(self._df.take_with_series(s._s)) + # `str` overlaps with `Sequence[str]` + # We can ignore this but we must keep this overload ordering + @overload + def __getitem__( + self, key: tuple[SingleIndexSelector, SingleColSelector] + ) -> Any: ... @overload - def __getitem__(self, item: str) -> Series: ... + def __getitem__( # type: ignore[overload-overlap] + self, key: str | tuple[MultiIndexSelector, SingleColSelector] + ) -> Series: ... @overload def __getitem__( self, - item: ( - int - | np.ndarray[Any, Any] + key: ( + SingleIndexSelector + | MultiIndexSelector | MultiColSelector - | tuple[int, MultiColSelector] - | tuple[MultiRowSelector, MultiColSelector] + | tuple[SingleIndexSelector, MultiColSelector] + | tuple[MultiIndexSelector, MultiColSelector] ), - ) -> Self: ... - - @overload - def __getitem__(self, item: tuple[int, int | str]) -> Any: ... - - @overload - def __getitem__(self, item: tuple[MultiRowSelector, int | str]) -> Series: ... + ) -> DataFrame: ... def __getitem__( self, - item: ( - str - | int - | np.ndarray[Any, Any] + key: ( + SingleIndexSelector + | SingleColSelector | MultiColSelector - | tuple[int, MultiColSelector] - | tuple[MultiRowSelector, MultiColSelector] - | tuple[MultiRowSelector, int | str] - | tuple[int, int | str] + | MultiIndexSelector + | tuple[SingleIndexSelector, SingleColSelector] + | tuple[SingleIndexSelector, MultiColSelector] + | tuple[MultiIndexSelector, SingleColSelector] + | tuple[MultiIndexSelector, MultiColSelector] ), - ) -> DataFrame | Series: - """Get item. Does quite a lot. Read the comments.""" - # fail on ['col1', 'col2', ..., 'coln'] - if ( - isinstance(item, tuple) - and len(item) > 1 # type: ignore[redundant-expr] - and all(isinstance(x, str) for x in item) - ): - raise KeyError(item) - - # select rows and columns at once - # every 2d selection, i.e. tuple is row column order, just like numpy - if isinstance(item, tuple) and len(item) == 2: - row_selection, col_selection = item - - # df[[], :] - if isinstance(row_selection, Sequence): - if len(row_selection) == 0: - # handle empty list by falling through to slice - row_selection = slice(0) - - # df[:, unknown] - if isinstance(row_selection, slice): - # multiple slices - # df[:, :] - if isinstance(col_selection, slice): - # slice can be - # by index - # [1:8] - # or by column name - # ["foo":"bar"] - # first we make sure that the slice is by index - start = col_selection.start - stop = col_selection.stop - if isinstance(col_selection.start, str): - start = self.get_column_index(col_selection.start) - if isinstance(col_selection.stop, str): - stop = self.get_column_index(col_selection.stop) + 1 - - col_selection = slice(start, stop, col_selection.step) - - df = self.__getitem__(self.columns[col_selection]) - return df[row_selection] - - # df[:, [True, False]] - if is_bool_sequence(col_selection) or ( - isinstance(col_selection, pl.Series) - and col_selection.dtype == Boolean - ): - if len(col_selection) != self.width: - msg = ( - f"expected {self.width} values when selecting columns by" - f" boolean mask, got {len(col_selection)}" - ) - raise ValueError(msg) - series_list = [] - for i, val in enumerate(col_selection): - if val: - series_list.append(self.to_series(i)) - - df = self.__class__(series_list) - return df[row_selection] - - # df[2, :] (select row as df) - if isinstance(row_selection, int): - if isinstance(col_selection, (slice, list)) or ( - _check_for_numpy(col_selection) - and isinstance(col_selection, np.ndarray) - ): - df = self[:, col_selection] - return df.slice(row_selection, 1) - - # df[:, "a"] - if isinstance(col_selection, str): - series = self.get_column(col_selection) - return series[row_selection] - - # df[:, 1] - if isinstance(col_selection, int): - if (col_selection >= 0 and col_selection >= self.width) or ( - col_selection < 0 and col_selection < -self.width - ): - msg = f"column index {col_selection!r} is out of bounds" - raise IndexError(msg) - series = self.to_series(col_selection) - return series[row_selection] - - if isinstance(col_selection, list): - # df[:, [1, 2]] - if is_int_sequence(col_selection): - for i in col_selection: - if (i >= 0 and i >= self.width) or (i < 0 and i < -self.width): - msg = f"column index {col_selection!r} is out of bounds" - raise IndexError(msg) - series_list = [self.to_series(i) for i in col_selection] - df = self.__class__(series_list) - return df[row_selection] - - df = self.__getitem__(col_selection) - return df.__getitem__(row_selection) - - # select single column - # df["foo"] - if isinstance(item, str): - return self.get_column(item) - - # df[idx] - if isinstance(item, int): - return self.slice(self._pos_idx(item, dim=0), 1) - - # df[range(n)] - if isinstance(item, range): - return self[range_to_slice(item)] - - # df[:] - if isinstance(item, slice): - return PolarsSlice(self).apply(item) - - # select rows by numpy mask or index - # df[np.array([1, 2, 3])] - # df[np.array([True, False, True])] - if _check_for_numpy(item) and isinstance(item, np.ndarray): - if item.ndim != 1: - msg = "multi-dimensional NumPy arrays not supported as index" - raise TypeError(msg) - if item.dtype.kind in ("i", "u"): - # Numpy array with signed or unsigned integers. - return self._take_with_series(numpy_to_idxs(item, self.shape[0])) - if isinstance(item[0], str): - return self._from_pydf(self._df.select(item)) - - if is_str_sequence(item, allow_str=False): - # select multiple columns - # df[["foo", "bar"]] - return self._from_pydf(self._df.select(item)) - elif is_int_sequence(item): - item = pl.Series("", item) # fall through to next if isinstance - - if isinstance(item, pl.Series): - dtype = item.dtype - if dtype == String: - return self._from_pydf(self._df.select(item)) - elif dtype.is_integer(): - return self._take_with_series(item._pos_idxs(self.shape[0])) - - # if no data has been returned, the operation is not supported - msg = ( - f"cannot use `__getitem__` on DataFrame with item {item!r}" - f" of type {type(item).__name__!r}" - ) - raise TypeError(msg) + ) -> DataFrame | Series | Any: + """Get part of the DataFrame as a new DataFrame, Series, or scalar.""" + return get_df_item_by_key(self, key) def __setitem__( self, @@ -1348,20 +1183,17 @@ def item(self, row: int | None = None, column: int | str | None = None) -> Any: f" frame has shape {self.shape!r}" ) raise ValueError(msg) - return self._df.select_at_idx(0).get_index(0) + return self._df.to_series(0).get_index(0) elif row is None or column is None: msg = "cannot call `.item()` with only one of `row` or `column`" raise ValueError(msg) s = ( - self._df.select_at_idx(column) + self._df.to_series(column) if isinstance(column, int) else self._df.get_column(column) ) - if s is None: - msg = f"column index {column!r} is out of bounds" - raise IndexError(msg) return s.get_index_signed(row) def to_arrow(self) -> pa.Table: @@ -2283,13 +2115,7 @@ def to_series(self, index: int = 0) -> Series: 8 ] """ - if not isinstance(index, int): - msg = f"index value {index!r} should be an int, but is {type(index).__name__!r}" - raise TypeError(msg) - - if index < 0: - index = len(self.columns) + index - return wrap_s(self._df.select_at_idx(index)) + return wrap_s(self._df.to_series(index)) def to_init_repr(self, n: int = 1000) -> str: """ @@ -8166,6 +7992,7 @@ def partition_by( f" Pass `by` as a list to silence this warning, e.g. `partition_by([{by!r}], as_dict=True)`.", version="0.20.4", ) + if include_key: if key_as_single_value: names = [p.get_column(by)[0] for p in partitions] # type: ignore[arg-type] diff --git a/py-polars/polars/series/series.py b/py-polars/polars/series/series.py index 20c615b5b46a..a515dc6ae983 100644 --- a/py-polars/polars/series/series.py +++ b/py-polars/polars/series/series.py @@ -28,7 +28,6 @@ arrow_to_pyseries, dataframe_to_pyseries, iterable_to_pyseries, - numpy_to_idxs, numpy_to_pyseries, pandas_to_pyseries, sequence_to_pyseries, @@ -47,13 +46,13 @@ deprecate_renamed_parameter, issue_deprecation_warning, ) +from polars._utils.getitem import get_series_item_by_key from polars._utils.unstable import unstable from polars._utils.various import ( BUILDING_SPHINX_DOCS, _is_generator, no_default, parse_version, - range_to_slice, scale_bytes, sphinx_accessor, warn_null_comparison, @@ -70,8 +69,6 @@ Enum, Float32, Float64, - Int8, - Int16, Int32, Int64, List, @@ -103,7 +100,6 @@ from polars.dependencies import pandas as pd from polars.dependencies import pyarrow as pa from polars.exceptions import ComputeError, ModuleUpgradeRequired, ShapeError -from polars.meta import get_index_type from polars.series.array import ArrayNameSpace from polars.series.binary import BinaryNameSpace from polars.series.categorical import CatNameSpace @@ -112,7 +108,6 @@ from polars.series.string import StringNameSpace from polars.series.struct import StructNameSpace from polars.series.utils import expr_dispatch, get_ffi_func -from polars.slice import PolarsSlice with contextlib.suppress(ImportError): # Module not available when building docs from polars.polars import PyDataFrame, PySeries @@ -138,6 +133,7 @@ InterpolationMethod, IntoExpr, IntoExprColumn, + MultiIndexSelector, NonNestedLiteral, NullBehavior, NumericLiteral, @@ -148,6 +144,7 @@ RollingInterpolationMethod, SearchSortedSide, SeriesBuffers, + SingleIndexSelector, SizeUnit, TemporalLiteral, ) @@ -1231,110 +1228,17 @@ def __iter__(self) -> Generator[Any, None, None]: for offset in range(0, self.len(), buffer_size): yield from self.slice(offset, buffer_size).to_list() - def _pos_idxs(self, size: int) -> Series: - # Unsigned or signed Series (ordered from fastest to slowest). - # - pl.UInt32 (polars) or pl.UInt64 (polars_u64_idx) Series indexes. - # - Other unsigned Series indexes are converted to pl.UInt32 (polars) - # or pl.UInt64 (polars_u64_idx). - # - Signed Series indexes are converted pl.UInt32 (polars) or - # pl.UInt64 (polars_u64_idx) after negative indexes are converted - # to absolute indexes. - - # pl.UInt32 (polars) or pl.UInt64 (polars_u64_idx). - idx_type = get_index_type() - - if self.dtype == idx_type: - return self - - if not self.dtype.is_integer(): - msg = "unsupported idxs datatype" - raise NotImplementedError(msg) - - if self.len() == 0: - return Series(self.name, [], dtype=idx_type) - - if idx_type == UInt32: - if self.dtype in {Int64, UInt64}: - if self.max() >= 2**32: # type: ignore[operator] - msg = "index positions should be smaller than 2^32" - raise ValueError(msg) - if self.dtype == Int64: - if self.min() < -(2**32): # type: ignore[operator] - msg = "index positions should be bigger than -2^32 + 1" - raise ValueError(msg) - - if self.dtype.is_signed_integer(): - if self.min() < 0: # type: ignore[operator] - if idx_type == UInt32: - idxs = self.cast(Int32) if self.dtype in {Int8, Int16} else self - else: - idxs = ( - self.cast(Int64) if self.dtype in {Int8, Int16, Int32} else self - ) - - # Update negative indexes to absolute indexes. - return ( - idxs.to_frame() - .select( - F.when(F.col(idxs.name) < 0) - .then(size + F.col(idxs.name)) - .otherwise(F.col(idxs.name)) - .cast(idx_type) - ) - .to_series(0) - ) - - return self.cast(idx_type) - - def _take_with_series(self, s: Series) -> Series: - return self._from_pyseries(self._s.take_with_series(s._s)) - @overload - def __getitem__(self, item: int) -> Any: ... + def __getitem__(self, key: SingleIndexSelector) -> Any: ... @overload - def __getitem__( - self, - item: Series | range | slice | np.ndarray[Any, Any] | list[int], - ) -> Series: ... + def __getitem__(self, key: MultiIndexSelector) -> Series: ... def __getitem__( - self, - item: int | Series | range | slice | np.ndarray[Any, Any] | list[int], - ) -> Any: - if isinstance(item, Series) and item.dtype.is_integer(): - return self._take_with_series(item._pos_idxs(self.len())) - - elif _check_for_numpy(item) and isinstance(item, np.ndarray): - return self._take_with_series(numpy_to_idxs(item, self.len())) - - # Integer - elif isinstance(item, int): - return self._s.get_index_signed(item) - - # Slice - elif isinstance(item, slice): - return PolarsSlice(self).apply(item) - - # Range - elif isinstance(item, range): - return self[range_to_slice(item)] - - # Sequence of integers (also triggers on empty sequence) - elif isinstance(item, Sequence) and ( - not item or (isinstance(item[0], int) and not isinstance(item[0], bool)) # type: ignore[redundant-expr] - ): - idx_series = Series("", item, dtype=Int64)._pos_idxs(self.len()) - if idx_series.has_nulls(): - msg = "cannot use `__getitem__` with index values containing nulls" - raise ValueError(msg) - return self._take_with_series(idx_series) - - msg = ( - f"cannot use `__getitem__` on Series of dtype {self.dtype!r}" - f" with argument {item!r} of type {type(item).__name__!r}" - ) - raise TypeError(msg) + self, key: SingleIndexSelector | MultiIndexSelector + ) -> Any | Series: + """Get part of the Series as a new Series or scalar.""" + return get_series_item_by_key(self, key) def __setitem__( self, diff --git a/py-polars/polars/type_aliases.py b/py-polars/polars/type_aliases.py index 924afb0f306e..78e4a5de2267 100644 --- a/py-polars/polars/type_aliases.py +++ b/py-polars/polars/type_aliases.py @@ -249,5 +249,32 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any: """Fetch results in batches.""" -AlchemyConnection = Union["Connection", "Engine", "Session"] -ConnectionOrCursor = Union[BasicConnection, BasicCursor, Cursor, AlchemyConnection] +AlchemyConnection: TypeAlias = Union["Connection", "Engine", "Session"] +ConnectionOrCursor: TypeAlias = Union[ + BasicConnection, BasicCursor, Cursor, AlchemyConnection +] + + +# Annotations for `__getitem__` methods +SingleIndexSelector: TypeAlias = int +MultiIndexSelector: TypeAlias = Union[ + slice, + range, + Sequence[int], + "Series", + "np.ndarray[Any, Any]", +] +SingleNameSelector: TypeAlias = str +MultiNameSelector: TypeAlias = Union[ + slice, + Sequence[str], + "Series", + "np.ndarray[Any, Any]", +] +BooleanMask: TypeAlias = Union[ + Sequence[bool], + "Series", + "np.ndarray[Any, Any]", +] +SingleColSelector: TypeAlias = Union[SingleIndexSelector, SingleNameSelector] +MultiColSelector: TypeAlias = Union[MultiIndexSelector, MultiNameSelector, BooleanMask] diff --git a/py-polars/src/dataframe/general.rs b/py-polars/src/dataframe/general.rs index dfc943ddb2c1..af96d67eb66a 100644 --- a/py-polars/src/dataframe/general.rs +++ b/py-polars/src/dataframe/general.rs @@ -5,6 +5,7 @@ use polars::prelude::*; use polars_core::frame::*; #[cfg(feature = "pivot")] use polars_lazy::frame::pivot::{pivot, pivot_stable}; +use pyo3::exceptions::PyIndexError; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; use pyo3::types::{PyBytes, PyList}; @@ -237,8 +238,22 @@ impl PyDataFrame { Ok(PySeries { series: s }) } - pub fn select_at_idx(&self, idx: usize) -> Option { - self.df.select_at_idx(idx).map(|s| PySeries::new(s.clone())) + pub fn to_series(&self, index: isize) -> PyResult { + let df = &self.df; + + let index_adjusted = if index < 0 { + df.width().checked_sub(index.unsigned_abs()) + } else { + Some(usize::try_from(index).unwrap()) + }; + + let s = index_adjusted.and_then(|i| df.select_at_idx(i)); + match s { + Some(s) => Ok(PySeries::new(s.clone())), + None => Err(PyIndexError::new_err( + polars_err!(oob = index, df.width()).to_string(), + )), + } } pub fn get_column_index(&self, name: &str) -> Option { @@ -254,8 +269,8 @@ impl PyDataFrame { Ok(series) } - pub fn select(&self, selection: Vec) -> PyResult { - let df = self.df.select(selection).map_err(PyPolarsErr::from)?; + pub fn select(&self, columns: Vec) -> PyResult { + let df = self.df.select(columns).map_err(PyPolarsErr::from)?; Ok(PyDataFrame::new(df)) } @@ -266,9 +281,9 @@ impl PyDataFrame { Ok(PyDataFrame::new(df)) } - pub fn take_with_series(&self, indices: &PySeries) -> PyResult { - let idx = indices.series.idx().map_err(PyPolarsErr::from)?; - let df = self.df.take(idx).map_err(PyPolarsErr::from)?; + pub fn gather_with_series(&self, indices: &PySeries) -> PyResult { + let indices = indices.series.idx().map_err(PyPolarsErr::from)?; + let df = self.df.take(indices).map_err(PyPolarsErr::from)?; Ok(PyDataFrame::new(df)) } diff --git a/py-polars/src/series/mod.rs b/py-polars/src/series/mod.rs index 7e34dfa91ff2..cf1b7dc03a9b 100644 --- a/py-polars/src/series/mod.rs +++ b/py-polars/src/series/mod.rs @@ -169,6 +169,7 @@ impl PySeries { } } + /// Get a value by index. fn get_index(&self, py: Python, index: usize) -> PyResult { let av = match self.series.get(index) { Ok(v) => v, @@ -194,10 +195,10 @@ impl PySeries { Ok(out) } - /// Get index but allow negative indices - fn get_index_signed(&self, py: Python, index: i64) -> PyResult { + /// Get a value by index, allowing negative indices. + fn get_index_signed(&self, py: Python, index: isize) -> PyResult { let index = if index < 0 { - match self.len().checked_sub(index.unsigned_abs() as usize) { + match self.len().checked_sub(index.unsigned_abs()) { Some(v) => v, None => { return Err(PyIndexError::new_err( @@ -206,7 +207,7 @@ impl PySeries { }, } } else { - index as usize + usize::try_from(index).unwrap() }; self.get_index(py, index) } @@ -309,10 +310,10 @@ impl PySeries { .into()) } - fn take_with_series(&self, indices: &PySeries) -> PyResult { - let idx = indices.series.idx().map_err(PyPolarsErr::from)?; - let take = self.series.take(idx).map_err(PyPolarsErr::from)?; - Ok(take.into()) + fn gather_with_series(&self, indices: &PySeries) -> PyResult { + let indices = indices.series.idx().map_err(PyPolarsErr::from)?; + let s = self.series.take(indices).map_err(PyPolarsErr::from)?; + Ok(s.into()) } fn null_count(&self) -> PyResult { diff --git a/py-polars/tests/unit/dataframe/test_df.py b/py-polars/tests/unit/dataframe/test_df.py index ce36e51a45e2..112bea712a91 100644 --- a/py-polars/tests/unit/dataframe/test_df.py +++ b/py-polars/tests/unit/dataframe/test_df.py @@ -168,8 +168,9 @@ def test_selection() -> None: assert df.to_series(0).name == "a" assert (df["a"] == df["a"]).sum() == 3 assert (df["c"] == df["a"].cast(str)).sum() == 0 - assert df[:, "a":"b"].rows() == [(1, 1.0), (2, 2.0), (3, 3.0)] # type: ignore[misc] - assert df[:, "a":"c"].columns == ["a", "b", "c"] # type: ignore[misc] + assert df[:, "a":"b"].rows() == [(1, 1.0), (2, 2.0), (3, 3.0)] # type: ignore[index, misc] + assert df[:, "a":"c"].columns == ["a", "b", "c"] # type: ignore[index, misc] + assert df[:, []].shape == (0, 0) expect = pl.DataFrame({"c": ["b"]}) assert_frame_equal(df[1, [2]], expect) expect = pl.DataFrame({"b": [1.0, 3.0]}) @@ -413,7 +414,19 @@ def test_to_series() -> None: assert_series_equal(df.to_series(2), df["z"]) assert_series_equal(df.to_series(-1), df["z"]) - with pytest.raises(TypeError, match="should be an int"): + +def test_to_series_bad_inputs() -> None: + df = pl.DataFrame({"x": [1, 2, 3], "y": [2, 3, 4], "z": [3, 4, 5]}) + + with pytest.raises(IndexError, match="index 5 is out of bounds"): + df.to_series(5) + + with pytest.raises(IndexError, match="index -100 is out of bounds"): + df.to_series(-100) + + with pytest.raises( + TypeError, match="'str' object cannot be interpreted as an integer" + ): df.to_series("x") # type: ignore[arg-type] @@ -1985,181 +1998,6 @@ def test_add_string() -> None: assert_frame_equal(("hello " + df), expected) -def test_getitem() -> None: - """Test all the methods to use [] on a dataframe.""" - df = pl.DataFrame({"a": [1.0, 2.0, 3.0, 4.0], "b": [3, 4, 5, 6]}) - - # expression - assert_frame_equal( - df.select(pl.col("a")), pl.DataFrame({"a": [1.0, 2.0, 3.0, 4.0]}) - ) - - # multiple slices. - # The first element refers to the rows, the second element to columns - assert_frame_equal(df[:, :], df) - - # str, always refers to a column name - assert_series_equal(df["a"], pl.Series("a", [1.0, 2.0, 3.0, 4.0])) - - # int, always refers to a row index (zero-based): index=1 => second row - assert_frame_equal(df[1], pl.DataFrame({"a": [2.0], "b": [4]})) - - # int, int. - # The first element refers to the rows, the second element to columns - assert df[2, 1] == 5 - assert df[2, -2] == 3.0 - - with pytest.raises(IndexError): - # Column index out of bounds - df[2, 2] - - with pytest.raises(IndexError): - # Column index out of bounds - df[2, -3] - - # int, list[int]. - # The first element refers to the rows, the second element to columns - assert_frame_equal(df[2, [1, 0]], pl.DataFrame({"b": [5], "a": [3.0]})) - assert_frame_equal(df[2, [-1, -2]], pl.DataFrame({"b": [5], "a": [3.0]})) - - with pytest.raises(IndexError): - # Column index out of bounds - df[2, [2, 0]] - - with pytest.raises(IndexError): - # Column index out of bounds - df[2, [2, -3]] - - # range, refers to rows - assert_frame_equal(df[range(1, 3)], pl.DataFrame({"a": [2.0, 3.0], "b": [4, 5]})) - - # slice. Below an example of taking every second row - assert_frame_equal(df[1::2], pl.DataFrame({"a": [2.0, 4.0], "b": [4, 6]})) - - # slice, empty slice - assert df[:0].columns == ["a", "b"] - assert len(df[:0]) == 0 - - # make mypy happy - empty: list[int] = [] - - # empty list with column selector drops rows but keeps columns - assert_frame_equal(df[empty, :], df[:0]) - - # empty list without column select return empty frame - assert_frame_equal(df[empty], pl.DataFrame({})) - - # numpy array: assumed to be row indices if integers, or columns if strings - - # numpy array: positive idxs and empty idx - for np_dtype in ( - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, - ): - assert_frame_equal( - df[np.array([1, 0, 3, 2, 3, 0], dtype=np_dtype)], - pl.DataFrame( - {"a": [2.0, 1.0, 4.0, 3.0, 4.0, 1.0], "b": [4, 3, 6, 5, 6, 3]} - ), - ) - assert df[np.array([], dtype=np_dtype)].columns == ["a", "b"] - - # numpy array: positive and negative idxs. - for np_dtype in (np.int8, np.int16, np.int32, np.int64): - assert_frame_equal( - df[np.array([-1, 0, -3, -2, 3, -4], dtype=np_dtype)], - pl.DataFrame( - {"a": [4.0, 1.0, 2.0, 3.0, 4.0, 1.0], "b": [6, 3, 4, 5, 6, 3]} - ), - ) - - # note that we cannot use floats (even if they could be casted to integer without - # loss) - with pytest.raises(TypeError): - _ = df[np.array([1.0])] - - with pytest.raises( - TypeError, - match="multi-dimensional NumPy arrays not supported", - ): - df[np.array([[0], [1]])] - - # sequences (lists or tuples; tuple only if length != 2) - # if strings or list of expressions, assumed to be column names - # if bools, assumed to be a row mask - # if integers, assumed to be row indices - assert_frame_equal(df[["a", "b"]], df) - assert_frame_equal(df.select([pl.col("a"), pl.col("b")]), df) - assert_frame_equal( - df[[1, -4, -1, 2, 1]], - pl.DataFrame({"a": [2.0, 1.0, 4.0, 3.0, 2.0], "b": [4, 3, 6, 5, 4]}), - ) - - # pl.Series: strings for column selections. - assert_frame_equal(df[pl.Series("", ["a", "b"])], df) - - # pl.Series: positive idxs or empty idxs for row selection. - for pl_dtype in ( - pl.Int8, - pl.Int16, - pl.Int32, - pl.Int64, - pl.UInt8, - pl.UInt16, - pl.UInt32, - pl.UInt64, - ): - assert_frame_equal( - df[pl.Series("", [1, 0, 3, 2, 3, 0], dtype=pl_dtype)], - pl.DataFrame( - {"a": [2.0, 1.0, 4.0, 3.0, 4.0, 1.0], "b": [4, 3, 6, 5, 6, 3]} - ), - ) - assert df[pl.Series("", [], dtype=pl_dtype)].columns == ["a", "b"] - - # pl.Series: positive and negative idxs for row selection. - for pl_dtype in (pl.Int8, pl.Int16, pl.Int32, pl.Int64): - assert_frame_equal( - df[pl.Series("", [-1, 0, -3, -2, 3, -4], dtype=pl_dtype)], - pl.DataFrame( - {"a": [4.0, 1.0, 2.0, 3.0, 4.0, 1.0], "b": [6, 3, 4, 5, 6, 3]} - ), - ) - - # Boolean masks not supported - with pytest.raises(TypeError): - df[np.array([True, False, True])] - with pytest.raises(TypeError): - df[[True, False, True], [False, True]] # type: ignore[index] - with pytest.raises(TypeError): - df[pl.Series([True, False, True]), "b"] - - # wrong length boolean mask for column selection - with pytest.raises( - ValueError, - match=f"expected {df.width} values when selecting columns by boolean mask", - ): - df[:, [True, False, True]] - - # 5343 - df = pl.DataFrame( - { - f"foo{col}": [n**col for n in range(5)] # 5 rows - for col in range(12) # 12 columns - } - ) - assert df[4, 4] == 256 - assert df[4, 5] == 1024 - assert_frame_equal(df[4, [2]], pl.DataFrame({"foo2": [16]})) - assert_frame_equal(df[4, [5]], pl.DataFrame({"foo5": [1024]})) - - def test_df_broadcast() -> None: df = pl.DataFrame({"a": [1, 2, 3]}, schema_overrides={"a": pl.UInt8}) out = df.with_columns(pl.Series("s", [[1, 2]])) diff --git a/py-polars/tests/unit/dataframe/test_getitem.py b/py-polars/tests/unit/dataframe/test_getitem.py index 1e4dd95fed3e..670236a71441 100644 --- a/py-polars/tests/unit/dataframe/test_getitem.py +++ b/py-polars/tests/unit/dataframe/test_getitem.py @@ -1,9 +1,14 @@ from __future__ import annotations +from typing import Any + import hypothesis.strategies as st +import numpy as np +import pytest from hypothesis import given import polars as pl +from polars.testing import assert_frame_equal, assert_series_equal from polars.testing.parametric import column, dataframes @@ -45,7 +50,7 @@ # │ null ┆ 1 ┆ 2 ┆ 5865 │ # └───────┴──────┴──────┴───────┘ ) -def test_frame_slice(df: pl.DataFrame) -> None: +def test_df_getitem_row_slice(df: pl.DataFrame) -> None: # take strategy-generated integer values from the frame as slice bounds. # use these bounds to slice the same frame, and then validate the result # against a py-native slice of the same data using the same bounds. @@ -63,3 +68,328 @@ def test_frame_slice(df: pl.DataFrame) -> None: assert ( sliced_py_data == sliced_df_data ), f"slice [{start}:{stop}:{step}] failed on df w/len={len(df)}" + + +def test_df_getitem_col_single_name() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + result = df[:, "a"] + expected = df.select("a").to_series() + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("input", "expected_cols"), + [ + (["a"], ["a"]), + (["a", "d"], ["a", "d"]), + (slice("b", "d"), ["b", "c", "d"]), + (pl.Series(["a", "b"]), ["a", "b"]), + (np.array(["c", "d"]), ["c", "d"]), + ], +) +def test_df_getitem_col_multiple_names(input: Any, expected_cols: list[str]) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]}) + result = df[:, input] + expected = df.select(expected_cols) + assert_frame_equal(result, expected) + + +def test_df_getitem_col_single_index() -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4]}) + result = df[:, 1] + expected = df.select("b").to_series() + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("input", "expected_cols"), + [ + ([0], ["a"]), + ([0, 3], ["a", "d"]), + (slice(1, 4), ["b", "c", "d"]), + (pl.Series([0, 1]), ["a", "b"]), + (np.array([2, 3]), ["c", "d"]), + ], +) +def test_df_getitem_col_multiple_indices(input: Any, expected_cols: list[str]) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]}) + result = df[:, input] + expected = df.select(expected_cols) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "mask", + [ + [True, False, True], + pl.Series([True, False, True]), + np.array([True, False, True]), + ], +) +def test_df_getitem_col_boolean_mask(mask: Any) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + result = df[:, mask] + expected = df.select("a", "c") + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("rng", "expected_cols"), + [ + (range(2), ["a", "b"]), + (range(1, 4), ["b", "c", "d"]), + (range(3, 0, -2), ["d", "b"]), + ], +) +def test_df_getitem_col_range(rng: range, expected_cols: list[str]) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6], "d": [7, 8]}) + result = df[:, rng] + expected = df.select(expected_cols) + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + "input", [[], (), pl.Series(dtype=pl.Int64), np.array([], dtype=np.uint32)] +) +def test_df_getitem_col_empty_inputs(input: Any) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]}) + result = df[:, input] + expected = pl.DataFrame() + assert_frame_equal(result, expected) + + +@pytest.mark.parametrize( + ("input", "match"), + [ + ( + [0.0, 1.0], + "cannot select columns using Sequence with elements of type 'float'", + ), + ( + pl.Series([[1, 2], [3, 4]]), + "cannot select columns using Series of type List\\(Int64\\)", + ), + ( + np.array([0.0, 1.0]), + "cannot select columns using NumPy array of type float64", + ), + (object(), "cannot select columns using key of type 'object'"), + ], +) +def test_df_getitem_col_invalid_inputs(input: Any, match: str) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]}) + with pytest.raises(TypeError, match=match): + df[:, input] + + +@pytest.mark.parametrize( + ("input", "match"), + [ + (["a", 2], "'int' object cannot be converted to 'PyString'"), + ([1, "c"], "'str' object cannot be interpreted as an integer"), + ], +) +def test_df_getitem_col_mixed_inputs(input: list[Any], match: str) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]}) + with pytest.raises(TypeError, match=match): + df[:, input] + + +@pytest.mark.parametrize( + ("input", "match"), + [ + ([0.0, 1.0], "'float' object cannot be interpreted as an integer"), + ( + pl.Series([[1, 2], [3, 4]]), + "cannot treat Series of type List\\(Int64\\) as indices", + ), + (np.array([0.0, 1.0]), "cannot treat NumPy array of type float64 as indices"), + (object(), "cannot select rows using key of type 'object'"), + ], +) +def test_df_getitem_row_invalid_inputs(input: Any, match: str) -> None: + df = pl.DataFrame({"a": [1, 2], "b": [3.0, 4.0]}) + with pytest.raises(TypeError, match=match): + df[input, :] + + +def test_df_getitem_row_range() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [5.0, 6.0, 7.0, 8.0]}) + result = df[range(3, 0, -2), :] + expected = pl.DataFrame({"a": [4, 2], "b": [8.0, 6.0]}) + assert_frame_equal(result, expected) + + +def test_df_getitem_row_range_single_input() -> None: + df = pl.DataFrame({"a": [1, 2, 3, 4], "b": [5.0, 6.0, 7.0, 8.0]}) + result = df[range(1, 3)] + expected = pl.DataFrame({"a": [2, 3], "b": [6.0, 7.0]}) + assert_frame_equal(result, expected) + + +def test_df_getitem() -> None: + """Test all the methods to use [] on a dataframe.""" + df = pl.DataFrame({"a": [1.0, 2.0, 3.0, 4.0], "b": [3, 4, 5, 6]}) + + # multiple slices. + # The first element refers to the rows, the second element to columns + assert_frame_equal(df[:, :], df) + + # str, always refers to a column name + assert_series_equal(df["a"], pl.Series("a", [1.0, 2.0, 3.0, 4.0])) + + # int, always refers to a row index (zero-based): index=1 => second row + assert_frame_equal(df[1], pl.DataFrame({"a": [2.0], "b": [4]})) + + # int, int. + # The first element refers to the rows, the second element to columns + assert df[2, 1] == 5 + assert df[2, -2] == 3.0 + + with pytest.raises(IndexError): + # Column index out of bounds + df[2, 2] + + with pytest.raises(IndexError): + # Column index out of bounds + df[2, -3] + + # int, list[int]. + # The first element refers to the rows, the second element to columns + assert_frame_equal(df[2, [1, 0]], pl.DataFrame({"b": [5], "a": [3.0]})) + assert_frame_equal(df[2, [-1, -2]], pl.DataFrame({"b": [5], "a": [3.0]})) + + with pytest.raises(IndexError): + # Column index out of bounds + df[2, [2, 0]] + + with pytest.raises(IndexError): + # Column index out of bounds + df[2, [2, -3]] + + # slice. Below an example of taking every second row + assert_frame_equal(df[1::2], pl.DataFrame({"a": [2.0, 4.0], "b": [4, 6]})) + + # slice, empty slice + assert df[:0].columns == ["a", "b"] + assert len(df[:0]) == 0 + + # make mypy happy + empty: list[int] = [] + + # empty list with column selector drops rows but keeps columns + assert_frame_equal(df[empty, :], df[:0]) + + # empty list without column select return empty frame + assert_frame_equal(df[empty], pl.DataFrame({})) + + # numpy array: assumed to be row indices if integers, or columns if strings + + # numpy array: positive idxs and empty idx + for np_dtype in ( + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ): + assert_frame_equal( + df[np.array([1, 0, 3, 2, 3, 0], dtype=np_dtype)], + pl.DataFrame( + {"a": [2.0, 1.0, 4.0, 3.0, 4.0, 1.0], "b": [4, 3, 6, 5, 6, 3]} + ), + ) + assert df[np.array([], dtype=np_dtype)].columns == ["a", "b"] + + # numpy array: positive and negative idxs. + for np_dtype in (np.int8, np.int16, np.int32, np.int64): + assert_frame_equal( + df[np.array([-1, 0, -3, -2, 3, -4], dtype=np_dtype)], + pl.DataFrame( + {"a": [4.0, 1.0, 2.0, 3.0, 4.0, 1.0], "b": [6, 3, 4, 5, 6, 3]} + ), + ) + + # note that we cannot use floats (even if they could be casted to integer without + # loss) + with pytest.raises(TypeError): + _ = df[np.array([1.0])] + + with pytest.raises( + TypeError, match="multi-dimensional NumPy arrays not supported as index" + ): + df[np.array([[0], [1]])] + + # sequences (lists or tuples; tuple only if length != 2) + # if strings or list of expressions, assumed to be column names + # if bools, assumed to be a row mask + # if integers, assumed to be row indices + assert_frame_equal(df[["a", "b"]], df) + assert_frame_equal(df.select([pl.col("a"), pl.col("b")]), df) + assert_frame_equal( + df[[1, -4, -1, 2, 1]], + pl.DataFrame({"a": [2.0, 1.0, 4.0, 3.0, 2.0], "b": [4, 3, 6, 5, 4]}), + ) + + # pl.Series: strings for column selections. + assert_frame_equal(df[pl.Series("", ["a", "b"])], df) + + # pl.Series: positive idxs or empty idxs for row selection. + for pl_dtype in ( + pl.Int8, + pl.Int16, + pl.Int32, + pl.Int64, + pl.UInt8, + pl.UInt16, + pl.UInt32, + pl.UInt64, + ): + assert_frame_equal( + df[pl.Series("", [1, 0, 3, 2, 3, 0], dtype=pl_dtype)], + pl.DataFrame( + {"a": [2.0, 1.0, 4.0, 3.0, 4.0, 1.0], "b": [4, 3, 6, 5, 6, 3]} + ), + ) + assert df[pl.Series("", [], dtype=pl_dtype)].columns == ["a", "b"] + + # pl.Series: positive and negative idxs for row selection. + for pl_dtype in (pl.Int8, pl.Int16, pl.Int32, pl.Int64): + assert_frame_equal( + df[pl.Series("", [-1, 0, -3, -2, 3, -4], dtype=pl_dtype)], + pl.DataFrame( + {"a": [4.0, 1.0, 2.0, 3.0, 4.0, 1.0], "b": [6, 3, 4, 5, 6, 3]} + ), + ) + + # Boolean masks for rows not supported + with pytest.raises(TypeError): + df[[True, False, True], [False, True]] + with pytest.raises(TypeError): + df[pl.Series([True, False, True]), "b"] + + assert_frame_equal(df[np.array([True, False])], df[:, :1]) + + # wrong length boolean mask for column selection + with pytest.raises( + ValueError, + match=f"expected {df.width} values when selecting columns by boolean mask", + ): + df[:, [True, False, True]] + + +def test_df_getitem_5343() -> None: + # https://github.com/pola-rs/polars/issues/5343 + df = pl.DataFrame( + { + f"foo{col}": [n**col for n in range(5)] # 5 rows + for col in range(12) # 12 columns + } + ) + assert df[4, 4] == 256 + assert df[4, 5] == 1024 + assert_frame_equal(df[4, [2]], pl.DataFrame({"foo2": [16]})) + assert_frame_equal(df[4, [5]], pl.DataFrame({"foo5": [1024]})) diff --git a/py-polars/tests/unit/series/test_getitem.py b/py-polars/tests/unit/series/test_getitem.py index 07fb8979f211..3f106de3034f 100644 --- a/py-polars/tests/unit/series/test_getitem.py +++ b/py-polars/tests/unit/series/test_getitem.py @@ -1,6 +1,10 @@ from __future__ import annotations +from typing import Any + import hypothesis.strategies as st +import numpy as np +import pytest from hypothesis import given import polars as pl @@ -28,3 +32,72 @@ def test_series_getitem( assert sliced_py_data == sliced_pl_data, f"slice [{start}:{stop}:{step}] failed" assert_series_equal(srs, srs, check_exact=True) + + +@pytest.mark.parametrize( + ("rng", "expected_values"), + [ + (range(2), [1, 2]), + (range(1, 4), [2, 3, 4]), + (range(3, 0, -2), [4, 2]), + ], +) +def test_series_getitem_range(rng: range, expected_values: list[int]) -> None: + s = pl.Series([1, 2, 3, 4]) + result = s[rng] + expected = pl.Series(expected_values) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "mask", + [ + [True, False, True], + pl.Series([True, False, True]), + np.array([True, False, True]), + ], +) +def test_series_getitem_boolean_mask(mask: Any) -> None: + s = pl.Series([1, 2, 3]) + print(mask) + with pytest.raises( + TypeError, + match="selecting rows by passing a boolean mask to `__getitem__` is not supported", + ): + s[mask] + + +@pytest.mark.parametrize( + "input", [[], (), pl.Series(dtype=pl.Int64), np.array([], dtype=np.uint32)] +) +def test_series_getitem_empty_inputs(input: Any) -> None: + s = pl.Series("a", ["x", "y", "z"], dtype=pl.String) + result = s[input] + expected = pl.Series("a", dtype=pl.String) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize("indices", [[0, 2], pl.Series([0, 2]), np.array([0, 2])]) +def test_series_getitem_multiple_indices(indices: Any) -> None: + s = pl.Series(["x", "y", "z"]) + result = s[indices] + expected = pl.Series(["x", "z"]) + assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + ("input", "match"), + [ + ([0.0, 1.0], "'float' object cannot be interpreted as an integer"), + ( + pl.Series([[1, 2], [3, 4]]), + "cannot treat Series of type List\\(Int64\\) as indices", + ), + (np.array([0.0, 1.0]), "cannot treat NumPy array of type float64 as indices"), + (object(), "cannot select elements using key of type 'object'"), + ], +) +def test_df_getitem_col_invalid_inputs(input: Any, match: str) -> None: + s = pl.Series([1, 2, 3]) + with pytest.raises(TypeError, match=match): + s[input] diff --git a/py-polars/tests/unit/test_errors.py b/py-polars/tests/unit/test_errors.py index b06cf74d2abc..029911411eb4 100644 --- a/py-polars/tests/unit/test_errors.py +++ b/py-polars/tests/unit/test_errors.py @@ -170,13 +170,13 @@ def test_getitem_errs() -> None: with pytest.raises( TypeError, - match=r"cannot use `__getitem__` on DataFrame with item {'some'} of type 'set'", + match=r"cannot select columns using key of type 'set': {'some'}", ): df[{"some"}] # type: ignore[call-overload] with pytest.raises( TypeError, - match=r"cannot use `__getitem__` on Series of dtype Int64 with argument {'strange'} of type 'set'", + match=r"cannot select elements using key of type 'set': {'strange'}", ): df["a"][{"strange"}] # type: ignore[call-overload] @@ -533,8 +533,10 @@ def test_window_size_validation() -> None: def test_invalid_getitem_key_err() -> None: df = pl.DataFrame({"x": [1.0], "y": [1.0]}) - with pytest.raises(KeyError, match=r"('x', 'y')"): - df["x", "y"] # type: ignore[index] + with pytest.raises( + TypeError, match="cannot treat Series of type String as indices" + ): + df["x", "y"] def test_invalid_group_by_arg() -> None: