Skip to content

Commit

Permalink
fix(python): Address overly-permissive `expand_selector` function
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed May 15, 2024
1 parent a34ca2c commit e0a22f4
Show file tree
Hide file tree
Showing 3 changed files with 82 additions and 17 deletions.
13 changes: 11 additions & 2 deletions docs/src/python/user-guide/expressions/column-selections.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,22 @@
# --8<-- [start:selectors_is_selector_utility]
from polars.selectors import is_selector

# a plain selector is recognised as a selector
out = cs.numeric()
print(is_selector(out))

# combining selectors with a set operator still yields a selector
out = cs.boolean() | cs.numeric()
print(is_selector(out))

# arithmetic with a literal is expected to yield a regular expression instead
out = cs.numeric() + pl.lit(123)
print(is_selector(out))
# --8<-- [end:selectors_is_selector_utility]

# --8<-- [start:selectors_colnames_utility]
from polars.selectors import expand_selector

# resolve a selector to the names of the matching columns in `df`
out = cs.temporal()
print(expand_selector(df, out))

# an inverted compound selector can be expanded in the same way
out = ~(cs.temporal() | cs.numeric())
print(expand_selector(df, out))
# --8<-- [end:selectors_colnames_utility]
78 changes: 64 additions & 14 deletions py-polars/polars/selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,12 @@ def is_selector(obj: Any) -> bool:
True
"""
# note: don't want to expose the "_selector_proxy_" object
return isinstance(obj, _selector_proxy_)
return isinstance(obj, _selector_proxy_) and hasattr(obj, "_attrs")


def expand_selector(
target: DataFrame | LazyFrame | Mapping[str, PolarsDataType], selector: SelectorType
target: DataFrame | LazyFrame | Mapping[str, PolarsDataType],
selector: SelectorType,
) -> tuple[str, ...]:
"""
Expand a selector to column names with respect to a specific frame or schema target.
Expand Down Expand Up @@ -116,6 +117,10 @@ def expand_selector(
>>> cs.expand_selector(schema, cs.float())
('colx', 'coly')
"""
if not is_selector(selector):
msg = f"expected a selector; found {selector!r} instead."
raise TypeError(msg)

if isinstance(target, Mapping):
from polars.dataframe import DataFrame

Expand Down Expand Up @@ -272,9 +277,7 @@ def __invert__(self) -> Self:

def __repr__(self) -> str:
if not hasattr(self, "_attrs"):
return re.sub(
r"<[\w.]+_selector_proxy_[\w ]+>", "<selector>", super().__repr__()
)
return repr(self.as_expr())
elif hasattr(self, "_repr_override"):
return self._repr_override
else:
Expand All @@ -293,7 +296,7 @@ def __repr__(self) -> str:
def __sub__(self, other: Any) -> SelectorType | Expr: # type: ignore[override]
if is_column(other):
other = by_name(other.meta.output_name())
if isinstance(other, _selector_proxy_) and hasattr(other, "_attrs"):
if is_selector(other):
return _selector_proxy_(
self.meta._as_selector().meta._selector_sub(other),
parameters={"self": self, "other": other},
Expand All @@ -305,7 +308,7 @@ def __sub__(self, other: Any) -> SelectorType | Expr: # type: ignore[override]
def __and__(self, other: Any) -> SelectorType | Expr: # type: ignore[override]
if is_column(other):
other = by_name(other.meta.output_name())
if isinstance(other, _selector_proxy_) and hasattr(other, "_attrs"):
if is_selector(other):
return _selector_proxy_(
self.meta._as_selector().meta._selector_and(other),
parameters={"self": self, "other": other},
Expand All @@ -317,7 +320,7 @@ def __and__(self, other: Any) -> SelectorType | Expr: # type: ignore[override]
def __or__(self, other: Any) -> SelectorType | Expr: # type: ignore[override]
if is_column(other):
other = by_name(other.meta.output_name())
if isinstance(other, _selector_proxy_) and hasattr(other, "_attrs"):
if is_selector(other):
return _selector_proxy_(
self.meta._as_selector().meta._selector_add(other),
parameters={"self": self, "other": other},
Expand All @@ -330,7 +333,7 @@ def __rand__(self, other: Any) -> SelectorType | Expr: # type: ignore[override]
# order of operation doesn't matter
if is_column(other):
other = by_name(other.meta.output_name())
if isinstance(other, _selector_proxy_) and hasattr(other, "_attrs"):
if is_selector(other):
return self.__and__(other)
else:
return self.as_expr().__rand__(other)
Expand All @@ -339,17 +342,57 @@ def __ror__(self, other: Any) -> SelectorType | Expr: # type: ignore[override]
# order of operation doesn't matter
if is_column(other):
other = by_name(other.meta.output_name())
if isinstance(other, _selector_proxy_) and hasattr(other, "_attrs"):
if is_selector(other):
return self.__or__(other)
else:
return self.as_expr().__ror__(other)

def as_expr(self) -> Expr:
    """
    Materialize the `selector` as a normal expression.

    This ensures that the operators `|`, `&`, `~` and `-`
    are applied on the data and not on the selector sets.

    Examples
    --------
    >>> import polars.selectors as cs
    >>> df = pl.DataFrame(
    ...     {
    ...         "colx": ["aa", "bb", "cc"],
    ...         "coly": [True, False, True],
    ...         "colz": [1, 2, 3],
    ...     }
    ... )

    Inverting the boolean selector will choose the non-boolean columns:

    >>> df.select(~cs.boolean())
    shape: (3, 2)
    ┌──────┬──────┐
    │ colx ┆ colz │
    │ ---  ┆ ---  │
    │ str  ┆ i64  │
    ╞══════╪══════╡
    │ aa   ┆ 1    │
    │ bb   ┆ 2    │
    │ cc   ┆ 3    │
    └──────┴──────┘

    To invert the *values* in the selected boolean columns, we need to
    materialize the selector as a standard expression instead:

    >>> df.select(~cs.boolean().as_expr())
    shape: (3, 1)
    ┌───────┐
    │ coly  │
    │ ---   │
    │ bool  │
    ╞═══════╡
    │ false │
    │ true  │
    │ false │
    └───────┘
    """
    return Expr._from_pyexpr(self._pyexpr)

Expand Down Expand Up @@ -2180,24 +2223,31 @@ def time() -> SelectorType:

# Names exported by `from polars.selectors import *`.
# Each name appears exactly once, in alphabetical order.
__all__ = [
    "all",
    "binary",
    "boolean",
    "by_dtype",
    "by_index",
    "by_name",
    "categorical",
    "contains",
    "date",
    "datetime",
    "decimal",
    "duration",
    "ends_with",
    "exclude",
    "expand_selector",
    "first",
    "float",
    "integer",
    "is_selector",
    "last",
    "matches",
    "numeric",
    "signed_integer",
    "starts_with",
    "string",
    "temporal",
    "time",
    "unsigned_integer",
]
8 changes: 7 additions & 1 deletion py-polars/tests/unit/test_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,10 +615,16 @@ def test_regex_expansion_exclude_10002() -> None:


def test_is_selector() -> None:
    """Verify that `is_selector` accepts selectors and rejects everything else."""
    # only actual/compound selectors should pass this check
    accepted = (
        cs.numeric(),
        cs.by_dtype(pl.UInt32) | pl.col("xyz"),
    )
    for candidate in accepted:
        assert is_selector(candidate)

    # expressions (and literals, etc) should fail
    rejected = (
        pl.col("cde"),
        pl.col("xyz"),
        cs.numeric().name.suffix(":num"),
        cs.date() + pl.col("time"),
        None,
        "x",
    )
    for candidate in rejected:
        assert not is_selector(candidate)


def test_selector_or() -> None:
Expand Down

0 comments on commit e0a22f4

Please sign in to comment.